Skip to content

Commit

Permalink
implement cancel
Browse files Browse the repository at this point in the history
all behavior tests passing in this branch
  • Loading branch information
andrewrk committed Aug 7, 2019
1 parent 1afbb53 commit 7e1fcb5
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 134 deletions.
2 changes: 1 addition & 1 deletion BRANCH_TODO
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
* go over the commented out tests in cancel.zig
* clean up the bitcasting of awaiter fn ptr
* compile error for error: expected anyframe->T, found 'anyframe'
* compile error for error: expected anyframe->T, found 'i32'
* await of a non async function
Expand Down
3 changes: 2 additions & 1 deletion src/all_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1556,6 +1556,7 @@ enum PanicMsgId {
PanicMsgIdBadAwait,
PanicMsgIdBadReturn,
PanicMsgIdResumedAnAwaitingFn,
PanicMsgIdResumedACancelingFn,
PanicMsgIdFrameTooSmall,
PanicMsgIdResumedFnPendingAwait,

Expand Down Expand Up @@ -3432,7 +3433,7 @@ struct IrInstructionErrorUnion {
struct IrInstructionCancel {
IrInstruction base;

IrInstruction *target;
IrInstruction *frame;
};

struct IrInstructionAtomicRmw {
Expand Down
3 changes: 3 additions & 0 deletions src/analyze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3811,6 +3811,9 @@ static void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn) {
} else if (fn->inferred_async_node->type == NodeTypeAwaitExpr) {
add_error_note(g, msg, fn->inferred_async_node,
buf_sprintf("await is a suspend point"));
} else if (fn->inferred_async_node->type == NodeTypeCancel) {
add_error_note(g, msg, fn->inferred_async_node,
buf_sprintf("cancel is a suspend point"));
} else {
zig_unreachable();
}
Expand Down
109 changes: 69 additions & 40 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -911,11 +911,13 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) {
case PanicMsgIdBadResume:
return buf_create_from_str("resumed an async function which already returned");
case PanicMsgIdBadAwait:
return buf_create_from_str("async function awaited twice");
return buf_create_from_str("async function awaited/canceled twice");
case PanicMsgIdBadReturn:
return buf_create_from_str("async function returned twice");
case PanicMsgIdResumedAnAwaitingFn:
return buf_create_from_str("awaiting function resumed");
case PanicMsgIdResumedACancelingFn:
return buf_create_from_str("canceling function resumed");
case PanicMsgIdFrameTooSmall:
return buf_create_from_str("frame too small");
case PanicMsgIdResumedFnPendingAwait:
Expand Down Expand Up @@ -2189,12 +2191,12 @@ static void gen_assert_resume_id(CodeGen *g, IrInstruction *source_instr, Resume
if (end_bb == nullptr) end_bb = LLVMAppendBasicBlock(g->cur_fn_val, "OkResume");
LLVMValueRef ok_bit;
if (resume_id == ResumeIdAwaitEarlyReturn) {
LLVMValueRef last_value = LLVMBuildSub(g->builder, LLVMConstAllOnes(usize_type_ref),
LLVMConstInt(usize_type_ref, ResumeIdAwaitEarlyReturn, false), "");
LLVMValueRef last_value = LLVMConstSub(LLVMConstAllOnes(usize_type_ref),
LLVMConstInt(usize_type_ref, ResumeIdAwaitEarlyReturn, false));
ok_bit = LLVMBuildICmp(g->builder, LLVMIntULT, LLVMGetParam(g->cur_fn_val, 1), last_value, "");
} else {
LLVMValueRef expected_value = LLVMBuildSub(g->builder, LLVMConstAllOnes(usize_type_ref),
LLVMConstInt(usize_type_ref, resume_id, false), "");
LLVMValueRef expected_value = LLVMConstSub(LLVMConstAllOnes(usize_type_ref),
LLVMConstInt(usize_type_ref, resume_id, false));
ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, LLVMGetParam(g->cur_fn_val, 1), expected_value, "");
}
LLVMBuildCondBr(g->builder, ok_bit, end_bb, bad_resume_block);
Expand All @@ -2210,11 +2212,13 @@ static LLVMValueRef gen_resume(CodeGen *g, LLVMValueRef fn_val, LLVMValueRef tar
{
LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
if (fn_val == nullptr) {
if (g->anyframe_fn_type == nullptr) {
(void)get_llvm_type(g, get_any_frame_type(g, nullptr));
}
LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr, coro_fn_ptr_index, "");
fn_val = LLVMBuildLoad(g->builder, fn_ptr_ptr, "");
LLVMValueRef fn_val_typed = LLVMBuildLoad(g->builder, fn_ptr_ptr, "");
LLVMValueRef as_int = LLVMBuildPtrToInt(g->builder, fn_val_typed, usize_type_ref, "");
LLVMValueRef one = LLVMConstInt(usize_type_ref, 1, false);
LLVMValueRef mask_val = LLVMConstNot(one);
LLVMValueRef as_int_masked = LLVMBuildAnd(g->builder, as_int, mask_val, "");
fn_val = LLVMBuildIntToPtr(g->builder, as_int_masked, LLVMTypeOf(fn_val_typed), "");
}
if (arg_val == nullptr) {
arg_val = LLVMBuildSub(g->builder, LLVMConstAllOnes(usize_type_ref),
Expand All @@ -2226,6 +2230,17 @@ static LLVMValueRef gen_resume(CodeGen *g, LLVMValueRef fn_val, LLVMValueRef tar
return ZigLLVMBuildCall(g->builder, fn_val, args, 2, LLVMFastCallConv, ZigLLVM_FnInlineAuto, "");
}

static LLVMBasicBlockRef gen_suspend_begin(CodeGen *g, const char *name_hint) {
LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
LLVMBasicBlockRef resume_bb = LLVMAppendBasicBlock(g->cur_fn_val, name_hint);
size_t new_block_index = g->cur_resume_block_count;
g->cur_resume_block_count += 1;
LLVMValueRef new_block_index_val = LLVMConstInt(usize_type_ref, new_block_index, false);
LLVMAddCase(g->cur_async_switch_instr, new_block_index_val, resume_bb);
LLVMBuildStore(g->builder, new_block_index_val, g->cur_async_resume_index_ptr);
return resume_bb;
}

static LLVMValueRef ir_render_return_begin(CodeGen *g, IrExecutable *executable,
IrInstructionReturnBegin *instruction)
{
Expand All @@ -2245,12 +2260,7 @@ static LLVMValueRef ir_render_return_begin(CodeGen *g, IrExecutable *executable,
}

// Prepare to be suspended. We might end up not having to suspend though.
LLVMBasicBlockRef resume_bb = LLVMAppendBasicBlock(g->cur_fn_val, "ReturnResume");
size_t new_block_index = g->cur_resume_block_count;
g->cur_resume_block_count += 1;
LLVMValueRef new_block_index_val = LLVMConstInt(usize_type_ref, new_block_index, false);
LLVMAddCase(g->cur_async_switch_instr, new_block_index_val, resume_bb);
LLVMBuildStore(g->builder, new_block_index_val, g->cur_async_resume_index_ptr);
LLVMBasicBlockRef resume_bb = gen_suspend_begin(g, "ReturnResume");

LLVMValueRef zero = LLVMConstNull(usize_type_ref);
LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref);
Expand Down Expand Up @@ -2335,7 +2345,10 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutable *executable, IrIns

// We need to resume the caller by tail calling them.
ZigType *any_frame_type = get_any_frame_type(g, ret_type);
LLVMValueRef their_frame_ptr = LLVMBuildIntToPtr(g->builder, g->cur_async_prev_val,
LLVMValueRef one = LLVMConstInt(usize_type_ref, 1, false);
LLVMValueRef mask_val = LLVMConstNot(one);
LLVMValueRef masked_prev_val = LLVMBuildAnd(g->builder, g->cur_async_prev_val, mask_val, "");
LLVMValueRef their_frame_ptr = LLVMBuildIntToPtr(g->builder, masked_prev_val,
get_llvm_type(g, any_frame_type), "");
LLVMValueRef call_inst = gen_resume(g, nullptr, their_frame_ptr, ResumeIdReturn, nullptr);
ZigLLVMSetTailCall(call_inst);
Expand Down Expand Up @@ -3945,13 +3958,7 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
} else if (callee_is_async) {
ZigType *ptr_result_type = get_pointer_to_type(g, src_return_type, true);

LLVMBasicBlockRef call_bb = LLVMAppendBasicBlock(g->cur_fn_val, "CallResume");
size_t new_block_index = g->cur_resume_block_count;
g->cur_resume_block_count += 1;
LLVMValueRef new_block_index_val = LLVMConstInt(usize_type_ref, new_block_index, false);
LLVMAddCase(g->cur_async_switch_instr, new_block_index_val, call_bb);

LLVMBuildStore(g->builder, new_block_index_val, g->cur_async_resume_index_ptr);
LLVMBasicBlockRef call_bb = gen_suspend_begin(g, "CallResume");

LLVMValueRef call_inst = gen_resume(g, fn_val, frame_result_loc, ResumeIdCall, nullptr);
ZigLLVMSetTailCall(call_inst);
Expand Down Expand Up @@ -4672,10 +4679,6 @@ static LLVMValueRef ir_render_error_return_trace(CodeGen *g, IrExecutable *execu
return cur_err_ret_trace_val;
}

static LLVMValueRef ir_render_cancel(CodeGen *g, IrExecutable *executable, IrInstructionCancel *instruction) {
zig_panic("TODO cancel");
}

static LLVMAtomicOrdering to_LLVMAtomicOrdering(AtomicOrder atomic_order) {
switch (atomic_order) {
case AtomicOrderUnordered: return LLVMAtomicOrderingUnordered;
Expand Down Expand Up @@ -5416,13 +5419,7 @@ static LLVMValueRef ir_render_assert_non_null(CodeGen *g, IrExecutable *executab
static LLVMValueRef ir_render_suspend_begin(CodeGen *g, IrExecutable *executable,
IrInstructionSuspendBegin *instruction)
{
LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
instruction->resume_bb = LLVMAppendBasicBlock(g->cur_fn_val, "SuspendResume");
size_t new_block_index = g->cur_resume_block_count;
g->cur_resume_block_count += 1;
LLVMValueRef new_block_index_val = LLVMConstInt(usize_type_ref, new_block_index, false);
LLVMAddCase(g->cur_async_switch_instr, new_block_index_val, instruction->resume_bb);
LLVMBuildStore(g->builder, new_block_index_val, g->cur_async_resume_index_ptr);
instruction->resume_bb = gen_suspend_begin(g, "SuspendResume");
return nullptr;
}

Expand All @@ -5436,6 +5433,43 @@ static LLVMValueRef ir_render_suspend_finish(CodeGen *g, IrExecutable *executabl
return nullptr;
}

static LLVMValueRef ir_render_cancel(CodeGen *g, IrExecutable *executable, IrInstructionCancel *instruction) {
LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
LLVMValueRef zero = LLVMConstNull(usize_type_ref);
LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref);
LLVMValueRef one = LLVMConstInt(usize_type_ref, 1, false);

LLVMValueRef target_frame_ptr = ir_llvm_value(g, instruction->frame);
LLVMBasicBlockRef resume_bb = gen_suspend_begin(g, "CancelResume");

LLVMValueRef awaiter_val = LLVMBuildPtrToInt(g->builder, g->cur_frame_ptr, usize_type_ref, "");
LLVMValueRef awaiter_ored_val = LLVMBuildOr(g->builder, awaiter_val, one, "");
LLVMValueRef awaiter_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr, coro_awaiter_index, "");

LLVMValueRef prev_val = LLVMBuildAtomicRMW(g->builder, LLVMAtomicRMWBinOpXchg, awaiter_ptr, awaiter_ored_val,
LLVMAtomicOrderingRelease, g->is_single_threaded);

LLVMBasicBlockRef complete_suspend_block = LLVMAppendBasicBlock(g->cur_fn_val, "CancelSuspend");
LLVMBasicBlockRef early_return_block = LLVMAppendBasicBlock(g->cur_fn_val, "EarlyReturn");

LLVMValueRef switch_instr = LLVMBuildSwitch(g->builder, prev_val, resume_bb, 2);
LLVMAddCase(switch_instr, zero, complete_suspend_block);
LLVMAddCase(switch_instr, all_ones, early_return_block);

LLVMPositionBuilderAtEnd(g->builder, complete_suspend_block);
LLVMBuildRetVoid(g->builder);

LLVMPositionBuilderAtEnd(g->builder, early_return_block);
LLVMValueRef call_inst = gen_resume(g, nullptr, target_frame_ptr, ResumeIdAwaitEarlyReturn, awaiter_ored_val);
ZigLLVMSetTailCall(call_inst);
LLVMBuildRetVoid(g->builder);

LLVMPositionBuilderAtEnd(g->builder, resume_bb);
gen_assert_resume_id(g, &instruction->base, ResumeIdReturn, PanicMsgIdResumedACancelingFn, nullptr);

return nullptr;
}

static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInstructionAwaitGen *instruction) {
LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
LLVMValueRef zero = LLVMConstNull(usize_type_ref);
Expand All @@ -5444,12 +5478,7 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst
ZigType *ptr_result_type = get_pointer_to_type(g, result_type, true);

// Prepare to be suspended
LLVMBasicBlockRef resume_bb = LLVMAppendBasicBlock(g->cur_fn_val, "AwaitResume");
size_t new_block_index = g->cur_resume_block_count;
g->cur_resume_block_count += 1;
LLVMValueRef new_block_index_val = LLVMConstInt(usize_type_ref, new_block_index, false);
LLVMAddCase(g->cur_async_switch_instr, new_block_index_val, resume_bb);
LLVMBuildStore(g->builder, new_block_index_val, g->cur_async_resume_index_ptr);
LLVMBasicBlockRef resume_bb = gen_suspend_begin(g, "AwaitResume");

// At this point resuming the function will do the correct thing.
// This code is as if it is running inside the suspend block.
Expand Down
55 changes: 48 additions & 7 deletions src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3271,6 +3271,16 @@ static IrInstruction *ir_build_suspend_finish(IrBuilder *irb, Scope *scope, AstN
return &instruction->base;
}

static IrInstruction *ir_build_cancel(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *frame) {
IrInstructionCancel *instruction = ir_build_instruction<IrInstructionCancel>(irb, scope, source_node);
instruction->base.value.type = irb->codegen->builtin_types.entry_void;
instruction->frame = frame;

ir_ref_instruction(frame, irb->current_basic_block);

return &instruction->base;
}

static IrInstruction *ir_build_await_src(IrBuilder *irb, Scope *scope, AstNode *source_node,
IrInstruction *frame, ResultLoc *result_loc)
{
Expand Down Expand Up @@ -7820,11 +7830,26 @@ static IrInstruction *ir_gen_fn_proto(IrBuilder *irb, Scope *parent_scope, AstNo
static IrInstruction *ir_gen_cancel(IrBuilder *irb, Scope *scope, AstNode *node) {
assert(node->type == NodeTypeCancel);

IrInstruction *target_inst = ir_gen_node(irb, node->data.cancel_expr.expr, scope);
if (target_inst == irb->codegen->invalid_instruction)
ZigFn *fn_entry = exec_fn_entry(irb->exec);
if (!fn_entry) {
add_node_error(irb->codegen, node, buf_sprintf("cancel outside function definition"));
return irb->codegen->invalid_instruction;
}
ScopeSuspend *existing_suspend_scope = get_scope_suspend(scope);
if (existing_suspend_scope) {
if (!existing_suspend_scope->reported_err) {
ErrorMsg *msg = add_node_error(irb->codegen, node, buf_sprintf("cannot cancel inside suspend block"));
add_error_note(irb->codegen, msg, existing_suspend_scope->base.source_node, buf_sprintf("suspend block here"));
existing_suspend_scope->reported_err = true;
}
return irb->codegen->invalid_instruction;
}

IrInstruction *operand = ir_gen_node(irb, node->data.cancel_expr.expr, scope);
if (operand == irb->codegen->invalid_instruction)
return irb->codegen->invalid_instruction;

zig_panic("TODO ir_gen_cancel");
return ir_build_cancel(irb, scope, node, operand);
}

static IrInstruction *ir_gen_resume(IrBuilder *irb, Scope *scope, AstNode *node) {
Expand Down Expand Up @@ -23781,10 +23806,6 @@ static IrInstruction *ir_analyze_instruction_tag_type(IrAnalyze *ira, IrInstruct
}
}

static IrInstruction *ir_analyze_instruction_cancel(IrAnalyze *ira, IrInstructionCancel *instruction) {
zig_panic("TODO analyze cancel");
}

static ZigType *ir_resolve_atomic_operand_type(IrAnalyze *ira, IrInstruction *op) {
ZigType *operand_type = ir_resolve_type(ira, op);
if (type_is_invalid(operand_type))
Expand Down Expand Up @@ -24474,6 +24495,26 @@ static IrInstruction *ir_analyze_instruction_suspend_finish(IrAnalyze *ira,
return ir_build_suspend_finish(&ira->new_irb, instruction->base.scope, instruction->base.source_node, begin);
}

static IrInstruction *ir_analyze_instruction_cancel(IrAnalyze *ira, IrInstructionCancel *instruction) {
IrInstruction *frame = instruction->frame->child;
if (type_is_invalid(frame->value.type))
return ira->codegen->invalid_instruction;

ZigType *any_frame_type = get_any_frame_type(ira->codegen, nullptr);
IrInstruction *casted_frame = ir_implicit_cast(ira, frame, any_frame_type);
if (type_is_invalid(casted_frame->value.type))
return ira->codegen->invalid_instruction;

ZigFn *fn_entry = exec_fn_entry(ira->new_irb.exec);
ir_assert(fn_entry != nullptr, &instruction->base);

if (fn_entry->inferred_async_node == nullptr) {
fn_entry->inferred_async_node = instruction->base.source_node;
}

return ir_build_cancel(&ira->new_irb, instruction->base.scope, instruction->base.source_node, casted_frame);
}

static IrInstruction *ir_analyze_instruction_await(IrAnalyze *ira, IrInstructionAwaitSrc *instruction) {
IrInstruction *frame_ptr = instruction->frame->child;
if (type_is_invalid(frame_ptr->value.type))
Expand Down
2 changes: 1 addition & 1 deletion src/ir_print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1396,7 +1396,7 @@ static void ir_print_error_union(IrPrint *irp, IrInstructionErrorUnion *instruct

static void ir_print_cancel(IrPrint *irp, IrInstructionCancel *instruction) {
fprintf(irp->f, "cancel ");
ir_print_other_instruction(irp, instruction->target);
ir_print_other_instruction(irp, instruction->frame);
}

static void ir_print_atomic_rmw(IrPrint *irp, IrInstructionAtomicRmw *instruction) {
Expand Down
Loading

0 comments on commit 7e1fcb5

Please sign in to comment.