Skip to content

Commit

Permalink
error return trace across suspend points
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewrk committed Aug 6, 2019
1 parent 17199b0 commit 966c9ea
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 10 deletions.
1 change: 1 addition & 0 deletions src/all_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1737,6 +1737,7 @@ struct CodeGen {
LLVMValueRef stacksave_fn_val;
LLVMValueRef stackrestore_fn_val;
LLVMValueRef write_register_fn_val;
LLVMValueRef merge_err_ret_traces_fn_val;
LLVMValueRef sp_md_node;
LLVMValueRef err_name_table;
LLVMValueRef safety_crash_err_fn;
Expand Down
151 changes: 146 additions & 5 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2026,18 +2026,159 @@ void walk_function_params(CodeGen *g, ZigType *fn_type, FnWalk *fn_walk) {
}
}

static LLVMValueRef get_merge_err_ret_traces_fn_val(CodeGen *g) {
if (g->merge_err_ret_traces_fn_val)
return g->merge_err_ret_traces_fn_val;

assert(g->stack_trace_type != nullptr);

LLVMTypeRef param_types[] = {
get_llvm_type(g, get_ptr_to_stack_trace_type(g)),
get_llvm_type(g, get_ptr_to_stack_trace_type(g)),
};
LLVMTypeRef fn_type_ref = LLVMFunctionType(LLVMVoidType(), param_types, 2, false);

Buf *fn_name = get_mangled_name(g, buf_create_from_str("__zig_merge_error_return_traces"), false);
LLVMValueRef fn_val = LLVMAddFunction(g->module, buf_ptr(fn_name), fn_type_ref);
LLVMSetLinkage(fn_val, LLVMInternalLinkage);
LLVMSetFunctionCallConv(fn_val, get_llvm_cc(g, CallingConventionUnspecified));
addLLVMFnAttr(fn_val, "nounwind");
add_uwtable_attr(g, fn_val);
// Error return trace memory is in the stack, which is impossible to be at address 0
// on any architecture.
addLLVMArgAttr(fn_val, (unsigned)0, "nonnull");
addLLVMArgAttr(fn_val, (unsigned)0, "noalias");
addLLVMArgAttr(fn_val, (unsigned)0, "writeonly");
// Error return trace memory is in the stack, which is impossible to be at address 0
// on any architecture.
addLLVMArgAttr(fn_val, (unsigned)1, "nonnull");
addLLVMArgAttr(fn_val, (unsigned)1, "noalias");
addLLVMArgAttr(fn_val, (unsigned)1, "readonly");
if (g->build_mode == BuildModeDebug) {
ZigLLVMAddFunctionAttr(fn_val, "no-frame-pointer-elim", "true");
ZigLLVMAddFunctionAttr(fn_val, "no-frame-pointer-elim-non-leaf", nullptr);
}

// this is above the ZigLLVMClearCurrentDebugLocation
LLVMValueRef add_error_return_trace_addr_fn_val = get_add_error_return_trace_addr_fn(g);

LLVMBasicBlockRef entry_block = LLVMAppendBasicBlock(fn_val, "Entry");
LLVMBasicBlockRef prev_block = LLVMGetInsertBlock(g->builder);
LLVMValueRef prev_debug_location = LLVMGetCurrentDebugLocation(g->builder);
LLVMPositionBuilderAtEnd(g->builder, entry_block);
ZigLLVMClearCurrentDebugLocation(g->builder);

// var frame_index: usize = undefined;
// var frames_left: usize = undefined;
// if (src_stack_trace.index < src_stack_trace.instruction_addresses.len) {
// frame_index = 0;
// frames_left = src_stack_trace.index;
// if (frames_left == 0) return;
// } else {
// frame_index = (src_stack_trace.index + 1) % src_stack_trace.instruction_addresses.len;
// frames_left = src_stack_trace.instruction_addresses.len;
// }
// while (true) {
// __zig_add_err_ret_trace_addr(dest_stack_trace, src_stack_trace.instruction_addresses[frame_index]);
// frames_left -= 1;
// if (frames_left == 0) return;
// frame_index = (frame_index + 1) % src_stack_trace.instruction_addresses.len;
// }
LLVMBasicBlockRef return_block = LLVMAppendBasicBlock(fn_val, "Return");

LLVMValueRef frame_index_ptr = LLVMBuildAlloca(g->builder, g->builtin_types.entry_usize->llvm_type, "frame_index");
LLVMValueRef frames_left_ptr = LLVMBuildAlloca(g->builder, g->builtin_types.entry_usize->llvm_type, "frames_left");

LLVMValueRef dest_stack_trace_ptr = LLVMGetParam(fn_val, 0);
LLVMValueRef src_stack_trace_ptr = LLVMGetParam(fn_val, 1);

size_t src_index_field_index = g->stack_trace_type->data.structure.fields[0].gen_index;
size_t src_addresses_field_index = g->stack_trace_type->data.structure.fields[1].gen_index;
LLVMValueRef src_index_field_ptr = LLVMBuildStructGEP(g->builder, src_stack_trace_ptr,
(unsigned)src_index_field_index, "");
LLVMValueRef src_addresses_field_ptr = LLVMBuildStructGEP(g->builder, src_stack_trace_ptr,
(unsigned)src_addresses_field_index, "");
ZigType *slice_type = g->stack_trace_type->data.structure.fields[1].type_entry;
size_t ptr_field_index = slice_type->data.structure.fields[slice_ptr_index].gen_index;
LLVMValueRef src_ptr_field_ptr = LLVMBuildStructGEP(g->builder, src_addresses_field_ptr, (unsigned)ptr_field_index, "");
size_t len_field_index = slice_type->data.structure.fields[slice_len_index].gen_index;
LLVMValueRef src_len_field_ptr = LLVMBuildStructGEP(g->builder, src_addresses_field_ptr, (unsigned)len_field_index, "");
LLVMValueRef src_index_val = LLVMBuildLoad(g->builder, src_index_field_ptr, "");
LLVMValueRef src_ptr_val = LLVMBuildLoad(g->builder, src_ptr_field_ptr, "");
LLVMValueRef src_len_val = LLVMBuildLoad(g->builder, src_len_field_ptr, "");
LLVMValueRef no_wrap_bit = LLVMBuildICmp(g->builder, LLVMIntULT, src_index_val, src_len_val, "");
LLVMBasicBlockRef no_wrap_block = LLVMAppendBasicBlock(fn_val, "NoWrap");
LLVMBasicBlockRef yes_wrap_block = LLVMAppendBasicBlock(fn_val, "YesWrap");
LLVMBasicBlockRef loop_block = LLVMAppendBasicBlock(fn_val, "Loop");
LLVMBuildCondBr(g->builder, no_wrap_bit, no_wrap_block, yes_wrap_block);

LLVMPositionBuilderAtEnd(g->builder, no_wrap_block);
LLVMValueRef usize_zero = LLVMConstNull(g->builtin_types.entry_usize->llvm_type);
LLVMBuildStore(g->builder, usize_zero, frame_index_ptr);
LLVMBuildStore(g->builder, src_index_val, frames_left_ptr);
LLVMValueRef frames_left_eq_zero_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, src_index_val, usize_zero, "");
LLVMBuildCondBr(g->builder, frames_left_eq_zero_bit, return_block, loop_block);

LLVMPositionBuilderAtEnd(g->builder, yes_wrap_block);
LLVMValueRef usize_one = LLVMConstInt(g->builtin_types.entry_usize->llvm_type, 1, false);
LLVMValueRef plus_one = LLVMBuildNUWAdd(g->builder, src_index_val, usize_one, "");
LLVMValueRef mod_len = LLVMBuildURem(g->builder, plus_one, src_len_val, "");
LLVMBuildStore(g->builder, mod_len, frame_index_ptr);
LLVMBuildStore(g->builder, src_len_val, frames_left_ptr);
LLVMBuildBr(g->builder, loop_block);

LLVMPositionBuilderAtEnd(g->builder, loop_block);
LLVMValueRef ptr_index = LLVMBuildLoad(g->builder, frame_index_ptr, "");
LLVMValueRef addr_ptr = LLVMBuildInBoundsGEP(g->builder, src_ptr_val, &ptr_index, 1, "");
LLVMValueRef this_addr_val = LLVMBuildLoad(g->builder, addr_ptr, "");
LLVMValueRef args[] = {dest_stack_trace_ptr, this_addr_val};
ZigLLVMBuildCall(g->builder, add_error_return_trace_addr_fn_val, args, 2, get_llvm_cc(g, CallingConventionUnspecified), ZigLLVM_FnInlineAlways, "");
LLVMValueRef prev_frames_left = LLVMBuildLoad(g->builder, frames_left_ptr, "");
LLVMValueRef new_frames_left = LLVMBuildNUWSub(g->builder, prev_frames_left, usize_one, "");
LLVMValueRef done_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, new_frames_left, usize_zero, "");
LLVMBasicBlockRef continue_block = LLVMAppendBasicBlock(fn_val, "Continue");
LLVMBuildCondBr(g->builder, done_bit, return_block, continue_block);

LLVMPositionBuilderAtEnd(g->builder, return_block);
LLVMBuildRetVoid(g->builder);

LLVMPositionBuilderAtEnd(g->builder, continue_block);
LLVMBuildStore(g->builder, new_frames_left, frames_left_ptr);
LLVMValueRef prev_index = LLVMBuildLoad(g->builder, frame_index_ptr, "");
LLVMValueRef index_plus_one = LLVMBuildNUWAdd(g->builder, prev_index, usize_one, "");
LLVMValueRef index_mod_len = LLVMBuildURem(g->builder, index_plus_one, src_len_val, "");
LLVMBuildStore(g->builder, index_mod_len, frame_index_ptr);
LLVMBuildBr(g->builder, loop_block);

LLVMPositionBuilderAtEnd(g->builder, prev_block);
if (!g->strip_debug_symbols) {
LLVMSetCurrentDebugLocation(g->builder, prev_debug_location);
}

g->merge_err_ret_traces_fn_val = fn_val;
return fn_val;

}
static LLVMValueRef ir_render_save_err_ret_addr(CodeGen *g, IrExecutable *executable,
IrInstructionSaveErrRetAddr *save_err_ret_addr_instruction)
{
assert(g->have_err_ret_tracing);

LLVMValueRef return_err_fn = get_return_err_fn(g);
LLVMValueRef args[] = {
get_cur_err_ret_trace_val(g, save_err_ret_addr_instruction->base.scope),
};
LLVMValueRef call_instruction = ZigLLVMBuildCall(g->builder, return_err_fn, args, 1,
LLVMValueRef my_err_trace_val = get_cur_err_ret_trace_val(g, save_err_ret_addr_instruction->base.scope);
ZigLLVMBuildCall(g->builder, return_err_fn, &my_err_trace_val, 1,
get_llvm_cc(g, CallingConventionUnspecified), ZigLLVM_FnInlineAuto, "");
return call_instruction;

if (fn_is_async(g->cur_fn) && g->cur_fn->calls_or_awaits_errorable_fn &&
codegen_fn_has_err_ret_tracing_arg(g, g->cur_fn->type_entry->data.fn.fn_type_id.return_type))
{
LLVMValueRef dest_trace_ptr = LLVMBuildLoad(g->builder, g->cur_err_ret_trace_val_arg, "");
LLVMValueRef args[] = { dest_trace_ptr, my_err_trace_val };
ZigLLVMBuildCall(g->builder, get_merge_err_ret_traces_fn_val(g), args, 2,
get_llvm_cc(g, CallingConventionUnspecified), ZigLLVM_FnInlineAuto, "");
}

return nullptr;
}

static void gen_assert_resume_id(CodeGen *g, IrInstruction *source_instr, ResumeId resume_id, PanicMsgId msg_id,
Expand Down
16 changes: 11 additions & 5 deletions test/runtime_safety.zig
Original file line number Diff line number Diff line change
Expand Up @@ -544,23 +544,29 @@ pub fn addCases(cases: *tests.CompareOutputContext) void {
\\ std.os.exit(126);
\\}
\\
\\var failing_frame: @Frame(failing) = undefined;
\\
\\pub fn main() void {
\\ const p = nonFailing();
\\ resume p;
\\ const p2 = async<std.debug.global_allocator> printTrace(p) catch unreachable;
\\ cancel p2;
\\ const p2 = async printTrace(p);
\\}
\\
\\fn nonFailing() promise->anyerror!void {
\\ return async<std.debug.global_allocator> failing() catch unreachable;
\\fn nonFailing() anyframe->anyerror!void {
\\ failing_frame = async failing();
\\ return &failing_frame;
\\}
\\
\\async fn failing() anyerror!void {
\\ suspend;
\\ return second();
\\}
\\
\\async fn second() anyerror!void {
\\ return error.Fail;
\\}
\\
\\async fn printTrace(p: promise->anyerror!void) void {
\\async fn printTrace(p: anyframe->anyerror!void) void {
\\ (await p) catch unreachable;
\\}
);
Expand Down

0 comments on commit 966c9ea

Please sign in to comment.