diff --git a/BRANCH_TODO b/BRANCH_TODO index 294bb42d5556..ca3888f391ce 100644 --- a/BRANCH_TODO +++ b/BRANCH_TODO @@ -1,4 +1,4 @@ - * go over the commented out tests in cancel.zig + * clean up the bitcasting of awaiter fn ptr * compile error for error: expected anyframe->T, found 'anyframe' * compile error for error: expected anyframe->T, found 'i32' * await of a non async function diff --git a/src/all_types.hpp b/src/all_types.hpp index e1fff953b469..a7fb542ad3da 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1556,6 +1556,7 @@ enum PanicMsgId { PanicMsgIdBadAwait, PanicMsgIdBadReturn, PanicMsgIdResumedAnAwaitingFn, + PanicMsgIdResumedACancelingFn, PanicMsgIdFrameTooSmall, PanicMsgIdResumedFnPendingAwait, @@ -3432,7 +3433,7 @@ struct IrInstructionErrorUnion { struct IrInstructionCancel { IrInstruction base; - IrInstruction *target; + IrInstruction *frame; }; struct IrInstructionAtomicRmw { diff --git a/src/analyze.cpp b/src/analyze.cpp index 764b28ed4587..cf71bd90f35b 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -3811,6 +3811,9 @@ static void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn) { } else if (fn->inferred_async_node->type == NodeTypeAwaitExpr) { add_error_note(g, msg, fn->inferred_async_node, buf_sprintf("await is a suspend point")); + } else if (fn->inferred_async_node->type == NodeTypeCancel) { + add_error_note(g, msg, fn->inferred_async_node, + buf_sprintf("cancel is a suspend point")); } else { zig_unreachable(); } diff --git a/src/codegen.cpp b/src/codegen.cpp index 7a27585e45a0..2a6c5f8b8f1a 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -911,11 +911,13 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) { case PanicMsgIdBadResume: return buf_create_from_str("resumed an async function which already returned"); case PanicMsgIdBadAwait: - return buf_create_from_str("async function awaited twice"); + return buf_create_from_str("async function awaited/canceled twice"); case PanicMsgIdBadReturn: return buf_create_from_str("async function returned twice"); case PanicMsgIdResumedAnAwaitingFn: return buf_create_from_str("awaiting function resumed"); + case PanicMsgIdResumedACancelingFn: + return buf_create_from_str("canceling function resumed"); case PanicMsgIdFrameTooSmall: return buf_create_from_str("frame too small"); case PanicMsgIdResumedFnPendingAwait: @@ -2189,12 +2191,12 @@ static void gen_assert_resume_id(CodeGen *g, IrInstruction *source_instr, Resume if (end_bb == nullptr) end_bb = LLVMAppendBasicBlock(g->cur_fn_val, "OkResume"); LLVMValueRef ok_bit; if (resume_id == ResumeIdAwaitEarlyReturn) { - LLVMValueRef last_value = LLVMBuildSub(g->builder, LLVMConstAllOnes(usize_type_ref), - LLVMConstInt(usize_type_ref, ResumeIdAwaitEarlyReturn, false), ""); + LLVMValueRef last_value = LLVMConstSub(LLVMConstAllOnes(usize_type_ref), + LLVMConstInt(usize_type_ref, ResumeIdAwaitEarlyReturn, false)); ok_bit = LLVMBuildICmp(g->builder, LLVMIntULT, LLVMGetParam(g->cur_fn_val, 1), last_value, ""); } else { - LLVMValueRef expected_value = LLVMBuildSub(g->builder, LLVMConstAllOnes(usize_type_ref), - LLVMConstInt(usize_type_ref, resume_id, false), ""); + LLVMValueRef expected_value = LLVMConstSub(LLVMConstAllOnes(usize_type_ref), + LLVMConstInt(usize_type_ref, resume_id, false)); ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, LLVMGetParam(g->cur_fn_val, 1), expected_value, ""); } LLVMBuildCondBr(g->builder, ok_bit, end_bb, bad_resume_block); @@ -2210,11 +2212,13 @@ static LLVMValueRef gen_resume(CodeGen *g, LLVMValueRef fn_val, LLVMValueRef tar { LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type; if (fn_val == nullptr) { - if (g->anyframe_fn_type == nullptr) { - (void)get_llvm_type(g, get_any_frame_type(g, nullptr)); - } LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr, coro_fn_ptr_index, ""); - fn_val = LLVMBuildLoad(g->builder, fn_ptr_ptr, ""); + LLVMValueRef fn_val_typed = LLVMBuildLoad(g->builder, fn_ptr_ptr, ""); + LLVMValueRef as_int = LLVMBuildPtrToInt(g->builder, fn_val_typed, usize_type_ref, ""); + LLVMValueRef one = LLVMConstInt(usize_type_ref, 1, false); + LLVMValueRef mask_val = LLVMConstNot(one); + LLVMValueRef as_int_masked = LLVMBuildAnd(g->builder, as_int, mask_val, ""); + fn_val = LLVMBuildIntToPtr(g->builder, as_int_masked, LLVMTypeOf(fn_val_typed), ""); } if (arg_val == nullptr) { arg_val = LLVMBuildSub(g->builder, LLVMConstAllOnes(usize_type_ref), @@ -2226,6 +2230,17 @@ static LLVMValueRef gen_resume(CodeGen *g, LLVMValueRef fn_val, LLVMValueRef tar return ZigLLVMBuildCall(g->builder, fn_val, args, 2, LLVMFastCallConv, ZigLLVM_FnInlineAuto, ""); } +static LLVMBasicBlockRef gen_suspend_begin(CodeGen *g, const char *name_hint) { + LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type; + LLVMBasicBlockRef resume_bb = LLVMAppendBasicBlock(g->cur_fn_val, name_hint); + size_t new_block_index = g->cur_resume_block_count; + g->cur_resume_block_count += 1; + LLVMValueRef new_block_index_val = LLVMConstInt(usize_type_ref, new_block_index, false); + LLVMAddCase(g->cur_async_switch_instr, new_block_index_val, resume_bb); + LLVMBuildStore(g->builder, new_block_index_val, g->cur_async_resume_index_ptr); + return resume_bb; +} + static LLVMValueRef ir_render_return_begin(CodeGen *g, IrExecutable *executable, IrInstructionReturnBegin *instruction) { @@ -2245,12 +2260,7 @@ static LLVMValueRef ir_render_return_begin(CodeGen *g, IrExecutable *executable, } // Prepare to be suspended. We might end up not having to suspend though. - LLVMBasicBlockRef resume_bb = LLVMAppendBasicBlock(g->cur_fn_val, "ReturnResume"); - size_t new_block_index = g->cur_resume_block_count; - g->cur_resume_block_count += 1; - LLVMValueRef new_block_index_val = LLVMConstInt(usize_type_ref, new_block_index, false); - LLVMAddCase(g->cur_async_switch_instr, new_block_index_val, resume_bb); - LLVMBuildStore(g->builder, new_block_index_val, g->cur_async_resume_index_ptr); + LLVMBasicBlockRef resume_bb = gen_suspend_begin(g, "ReturnResume"); LLVMValueRef zero = LLVMConstNull(usize_type_ref); LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref); @@ -2335,7 +2345,10 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutable *executable, IrIns // We need to resume the caller by tail calling them. ZigType *any_frame_type = get_any_frame_type(g, ret_type); - LLVMValueRef their_frame_ptr = LLVMBuildIntToPtr(g->builder, g->cur_async_prev_val, + LLVMValueRef one = LLVMConstInt(usize_type_ref, 1, false); + LLVMValueRef mask_val = LLVMConstNot(one); + LLVMValueRef masked_prev_val = LLVMBuildAnd(g->builder, g->cur_async_prev_val, mask_val, ""); + LLVMValueRef their_frame_ptr = LLVMBuildIntToPtr(g->builder, masked_prev_val, get_llvm_type(g, any_frame_type), ""); LLVMValueRef call_inst = gen_resume(g, nullptr, their_frame_ptr, ResumeIdReturn, nullptr); ZigLLVMSetTailCall(call_inst); @@ -3945,13 +3958,7 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr } else if (callee_is_async) { ZigType *ptr_result_type = get_pointer_to_type(g, src_return_type, true); - LLVMBasicBlockRef call_bb = LLVMAppendBasicBlock(g->cur_fn_val, "CallResume"); - size_t new_block_index = g->cur_resume_block_count; - g->cur_resume_block_count += 1; - LLVMValueRef new_block_index_val = LLVMConstInt(usize_type_ref, new_block_index, false); - LLVMAddCase(g->cur_async_switch_instr, new_block_index_val, call_bb); - - LLVMBuildStore(g->builder, new_block_index_val, g->cur_async_resume_index_ptr); + LLVMBasicBlockRef call_bb = gen_suspend_begin(g, "CallResume"); LLVMValueRef call_inst = gen_resume(g, fn_val, frame_result_loc, ResumeIdCall, nullptr); ZigLLVMSetTailCall(call_inst); @@ -4672,10 +4679,6 @@ static LLVMValueRef ir_render_error_return_trace(CodeGen *g, IrExecutable *execu return cur_err_ret_trace_val; } -static LLVMValueRef ir_render_cancel(CodeGen *g, IrExecutable *executable, IrInstructionCancel *instruction) { - zig_panic("TODO cancel"); -} - static LLVMAtomicOrdering to_LLVMAtomicOrdering(AtomicOrder atomic_order) { switch (atomic_order) { case AtomicOrderUnordered: return LLVMAtomicOrderingUnordered; @@ -5416,13 +5419,7 @@ static LLVMValueRef ir_render_assert_non_null(CodeGen *g, IrExecutable *executab static LLVMValueRef ir_render_suspend_begin(CodeGen *g, IrExecutable *executable, IrInstructionSuspendBegin *instruction) { - LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type; - instruction->resume_bb = LLVMAppendBasicBlock(g->cur_fn_val, "SuspendResume"); - size_t new_block_index = g->cur_resume_block_count; - g->cur_resume_block_count += 1; - LLVMValueRef new_block_index_val = LLVMConstInt(usize_type_ref, new_block_index, false); - LLVMAddCase(g->cur_async_switch_instr, new_block_index_val, instruction->resume_bb); - LLVMBuildStore(g->builder, new_block_index_val, g->cur_async_resume_index_ptr); + instruction->resume_bb = gen_suspend_begin(g, "SuspendResume"); return nullptr; } @@ -5436,6 +5433,43 @@ static LLVMValueRef ir_render_suspend_finish(CodeGen *g, IrExecutable *executabl return nullptr; } +static LLVMValueRef ir_render_cancel(CodeGen *g, IrExecutable *executable, IrInstructionCancel *instruction) { + LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type; + LLVMValueRef zero = LLVMConstNull(usize_type_ref); + LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref); + LLVMValueRef one = LLVMConstInt(usize_type_ref, 1, false); + + LLVMValueRef target_frame_ptr = ir_llvm_value(g, instruction->frame); + LLVMBasicBlockRef resume_bb = gen_suspend_begin(g, "CancelResume"); + + LLVMValueRef awaiter_val = LLVMBuildPtrToInt(g->builder, g->cur_frame_ptr, usize_type_ref, ""); + LLVMValueRef awaiter_ored_val = LLVMBuildOr(g->builder, awaiter_val, one, ""); + LLVMValueRef awaiter_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr, coro_awaiter_index, ""); + + LLVMValueRef prev_val = LLVMBuildAtomicRMW(g->builder, LLVMAtomicRMWBinOpXchg, awaiter_ptr, awaiter_ored_val, + LLVMAtomicOrderingRelease, g->is_single_threaded); + + LLVMBasicBlockRef complete_suspend_block = LLVMAppendBasicBlock(g->cur_fn_val, "CancelSuspend"); + LLVMBasicBlockRef early_return_block = LLVMAppendBasicBlock(g->cur_fn_val, "EarlyReturn"); + + LLVMValueRef switch_instr = LLVMBuildSwitch(g->builder, prev_val, resume_bb, 2); + LLVMAddCase(switch_instr, zero, complete_suspend_block); + LLVMAddCase(switch_instr, all_ones, early_return_block); + + LLVMPositionBuilderAtEnd(g->builder, complete_suspend_block); + LLVMBuildRetVoid(g->builder); + + LLVMPositionBuilderAtEnd(g->builder, early_return_block); + LLVMValueRef call_inst = gen_resume(g, nullptr, target_frame_ptr, ResumeIdAwaitEarlyReturn, awaiter_ored_val); + ZigLLVMSetTailCall(call_inst); + LLVMBuildRetVoid(g->builder); + + LLVMPositionBuilderAtEnd(g->builder, resume_bb); + gen_assert_resume_id(g, &instruction->base, ResumeIdReturn, PanicMsgIdResumedACancelingFn, nullptr); + + return nullptr; +} + static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInstructionAwaitGen *instruction) { LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type; LLVMValueRef zero = LLVMConstNull(usize_type_ref); @@ -5444,12 +5478,7 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst ZigType *ptr_result_type = get_pointer_to_type(g, result_type, true); // Prepare to be suspended - LLVMBasicBlockRef resume_bb = LLVMAppendBasicBlock(g->cur_fn_val, "AwaitResume"); - size_t new_block_index = g->cur_resume_block_count; - g->cur_resume_block_count += 1; - LLVMValueRef new_block_index_val = LLVMConstInt(usize_type_ref, new_block_index, false); - LLVMAddCase(g->cur_async_switch_instr, new_block_index_val, resume_bb); - LLVMBuildStore(g->builder, new_block_index_val, g->cur_async_resume_index_ptr); + LLVMBasicBlockRef resume_bb = gen_suspend_begin(g, "AwaitResume"); // At this point resuming the function will do the correct thing. // This code is as if it is running inside the suspend block. diff --git a/src/ir.cpp b/src/ir.cpp index 7cb868cab2ae..853cf4daa1b0 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -3271,6 +3271,16 @@ static IrInstruction *ir_build_suspend_finish(IrBuilder *irb, Scope *scope, AstN return &instruction->base; } +static IrInstruction *ir_build_cancel(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *frame) { + IrInstructionCancel *instruction = ir_build_instruction(irb, scope, source_node); + instruction->base.value.type = irb->codegen->builtin_types.entry_void; + instruction->frame = frame; + + ir_ref_instruction(frame, irb->current_basic_block); + + return &instruction->base; +} + static IrInstruction *ir_build_await_src(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *frame, ResultLoc *result_loc) { @@ -7820,11 +7830,26 @@ static IrInstruction *ir_gen_fn_proto(IrBuilder *irb, Scope *parent_scope, AstNo static IrInstruction *ir_gen_cancel(IrBuilder *irb, Scope *scope, AstNode *node) { assert(node->type == NodeTypeCancel); - IrInstruction *target_inst = ir_gen_node(irb, node->data.cancel_expr.expr, scope); - if (target_inst == irb->codegen->invalid_instruction) + ZigFn *fn_entry = exec_fn_entry(irb->exec); + if (!fn_entry) { + add_node_error(irb->codegen, node, buf_sprintf("cancel outside function definition")); + return irb->codegen->invalid_instruction; + } + ScopeSuspend *existing_suspend_scope = get_scope_suspend(scope); + if (existing_suspend_scope) { + if (!existing_suspend_scope->reported_err) { + ErrorMsg *msg = add_node_error(irb->codegen, node, buf_sprintf("cannot cancel inside suspend block")); + add_error_note(irb->codegen, msg, existing_suspend_scope->base.source_node, buf_sprintf("suspend block here")); + existing_suspend_scope->reported_err = true; + } + return irb->codegen->invalid_instruction; + } + + IrInstruction *operand = ir_gen_node(irb, node->data.cancel_expr.expr, scope); + if (operand == irb->codegen->invalid_instruction) return irb->codegen->invalid_instruction; - zig_panic("TODO ir_gen_cancel"); + return ir_build_cancel(irb, scope, node, operand); } static IrInstruction *ir_gen_resume(IrBuilder *irb, Scope *scope, AstNode *node) { @@ -23781,10 +23806,6 @@ static IrInstruction *ir_analyze_instruction_tag_type(IrAnalyze *ira, IrInstruct } } -static IrInstruction *ir_analyze_instruction_cancel(IrAnalyze *ira, IrInstructionCancel *instruction) { - zig_panic("TODO analyze cancel"); -} - static ZigType *ir_resolve_atomic_operand_type(IrAnalyze *ira, IrInstruction *op) { ZigType *operand_type = ir_resolve_type(ira, op); if (type_is_invalid(operand_type)) @@ -24474,6 +24495,26 @@ static IrInstruction *ir_analyze_instruction_suspend_finish(IrAnalyze *ira, return ir_build_suspend_finish(&ira->new_irb, instruction->base.scope, instruction->base.source_node, begin); } +static IrInstruction *ir_analyze_instruction_cancel(IrAnalyze *ira, IrInstructionCancel *instruction) { + IrInstruction *frame = instruction->frame->child; + if (type_is_invalid(frame->value.type)) + return ira->codegen->invalid_instruction; + + ZigType *any_frame_type = get_any_frame_type(ira->codegen, nullptr); + IrInstruction *casted_frame = ir_implicit_cast(ira, frame, any_frame_type); + if (type_is_invalid(casted_frame->value.type)) + return ira->codegen->invalid_instruction; + + ZigFn *fn_entry = exec_fn_entry(ira->new_irb.exec); + ir_assert(fn_entry != nullptr, &instruction->base); + + if (fn_entry->inferred_async_node == nullptr) { + fn_entry->inferred_async_node = instruction->base.source_node; + } + + return ir_build_cancel(&ira->new_irb, instruction->base.scope, instruction->base.source_node, casted_frame); +} + static IrInstruction *ir_analyze_instruction_await(IrAnalyze *ira, IrInstructionAwaitSrc *instruction) { IrInstruction *frame_ptr = instruction->frame->child; if (type_is_invalid(frame_ptr->value.type)) diff --git a/src/ir_print.cpp b/src/ir_print.cpp index c56a660e293f..0348cfc986a7 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -1396,7 +1396,7 @@ static void ir_print_error_union(IrPrint *irp, IrInstructionErrorUnion *instruct static void ir_print_cancel(IrPrint *irp, IrInstructionCancel *instruction) { fprintf(irp->f, "cancel "); - ir_print_other_instruction(irp, instruction->target); + ir_print_other_instruction(irp, instruction->frame); } static void ir_print_atomic_rmw(IrPrint *irp, IrInstructionAtomicRmw *instruction) { diff --git a/test/stage1/behavior/cancel.zig b/test/stage1/behavior/cancel.zig index cb8a0752799c..c8636212b04a 100644 --- a/test/stage1/behavior/cancel.zig +++ b/test/stage1/behavior/cancel.zig @@ -1,86 +1,94 @@ const std = @import("std"); +const expect = std.testing.expect; -//var defer_f1: bool = false; -//var defer_f2: bool = false; -//var defer_f3: bool = false; -// -//test "cancel forwards" { -// const p = async f1() catch unreachable; -// cancel p; -// std.testing.expect(defer_f1); -// std.testing.expect(defer_f2); -// std.testing.expect(defer_f3); -//} -// -//async fn f1() void { -// defer { -// defer_f1 = true; -// } -// await (async f2() catch unreachable); -//} -// -//async fn f2() void { -// defer { -// defer_f2 = true; -// } -// await (async f3() catch unreachable); -//} -// -//async fn f3() void { -// defer { -// defer_f3 = true; -// } -// suspend; -//} -// -//var defer_b1: bool = false; -//var defer_b2: bool = false; -//var defer_b3: bool = false; -//var defer_b4: bool = false; -// -//test "cancel backwards" { -// const p = async b1() catch unreachable; -// cancel p; -// std.testing.expect(defer_b1); -// std.testing.expect(defer_b2); -// std.testing.expect(defer_b3); -// std.testing.expect(defer_b4); -//} -// -//async fn b1() void { -// defer { -// defer_b1 = true; -// } -// await (async b2() catch unreachable); -//} -// -//var b4_handle: promise = undefined; -// -//async fn b2() void { -// const b3_handle = async b3() catch unreachable; -// resume b4_handle; -// cancel b4_handle; -// defer { -// defer_b2 = true; -// } -// const value = await b3_handle; -// @panic("unreachable"); -//} -// -//async fn b3() i32 { -// defer { -// defer_b3 = true; -// } -// await (async b4() catch unreachable); -// return 1234; -//} -// -//async fn b4() void { -// defer { -// defer_b4 = true; -// } -// suspend { -// b4_handle = @handle(); -// } -// suspend; -//} +var defer_f1: bool = false; +var defer_f2: bool = false; +var defer_f3: bool = false; +var f3_frame: anyframe = undefined; + +test "cancel forwards" { + _ = async atest1(); + resume f3_frame; +} + +fn atest1() void { + const p = async f1(); + cancel &p; + expect(defer_f1); + expect(defer_f2); + expect(defer_f3); +} + +async fn f1() void { + defer { + defer_f1 = true; + } + var f2_frame = async f2(); + await f2_frame; +} + +async fn f2() void { + defer { + defer_f2 = true; + } + f3(); +} + +async fn f3() void { + f3_frame = @frame(); + defer { + defer_f3 = true; + } + suspend; +} + +var defer_b1: bool = false; +var defer_b2: bool = false; +var defer_b3: bool = false; +var defer_b4: bool = false; + +test "cancel backwards" { + _ = async b1(); + resume b4_handle; + expect(defer_b1); + expect(defer_b2); + expect(defer_b3); + expect(defer_b4); +} + +async fn b1() void { + defer { + defer_b1 = true; + } + b2(); +} + +var b4_handle: anyframe = undefined; + +async fn b2() void { + const b3_handle = async b3(); + resume b4_handle; + defer { + defer_b2 = true; + } + const value = await b3_handle; + expect(value == 1234); +} + +async fn b3() i32 { + defer { + defer_b3 = true; + } + b4(); + return 1234; +} + +async fn b4() void { + defer { + defer_b4 = true; + } + suspend { + b4_handle = @frame(); + } + suspend; +}