From 0cff766ecb4dd9a833607208275ceae5905c49a3 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Wed, 3 Mar 2021 11:43:48 -0500 Subject: [PATCH] Lower Tapir after inlining Since optimized code is cached and used for inlining, we need to lower Tapir after optimized code is cached, to enable Tapir for IPO (i.e., we want to cache the optimized code that still has `detach` etc.). This patch moves Tapir lowering out of the standard `run_passes` optimization phase, to avoid caching the code that includes lowered code. Instead, lowering happens inside of `jl_emit_code` just before emitting LLVM IR. --- base/compiler/compiler.jl | 1 + base/compiler/optimize.jl | 2 ++ base/compiler/ssair/driver.jl | 1 - base/compiler/ssair/verify.jl | 7 ++++- base/compiler/tapirpasses.jl | 38 ++++++++++++++++++++++++- base/compiler/typeinfer.jl | 1 - src/codegen.cpp | 52 +++++++++++++++++++++++++++++++++++ src/gf.c | 7 +++++ src/jl_exported_funcs.inc | 1 + src/julia_internal.h | 1 + src/staticdata.c | 3 +- 11 files changed, 109 insertions(+), 5 deletions(-) diff --git a/base/compiler/compiler.jl b/base/compiler/compiler.jl index 67a642f0f40b1..503d162f436f2 100644 --- a/base/compiler/compiler.jl +++ b/base/compiler/compiler.jl @@ -141,6 +141,7 @@ include("compiler/tapirpasses.jl") include("compiler/bootstrap.jl") ccall(:jl_set_typeinf_func, Cvoid, (Any,), typeinf_ext_toplevel) +ccall(:jl_set_lower_tapir_func, Cvoid, (Any,), lower_tapir) include("compiler/parsing.jl") Core.eval(Core, :(_parse = Compiler.fl_parse)) diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index ac60a6c58f130..05bf477422481 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -62,6 +62,8 @@ mutable struct OptimizationState nssavalues = src.ssavaluetypes if nssavalues isa Int src.ssavaluetypes = Any[ Any for i = 1:nssavalues ] + else + nssavalues = length(src.ssavaluetypes) end nslots = length(src.slotflags) slottypes = src.slottypes diff --git a/base/compiler/ssair/driver.jl 
b/base/compiler/ssair/driver.jl index 482b058026c4f..6bba34518a6ba 100644 --- a/base/compiler/ssair/driver.jl +++ b/base/compiler/ssair/driver.jl @@ -140,7 +140,6 @@ function run_passes(ci::CodeInfo, nargs::Int, sv::OptimizationState) @timeit "type lift" ir = type_lift_pass!(ir) @timeit "compact 3" ir = compact!(ir) #@Base.show ir - @timeit "tapir" ir = lower_tapir!(ir) if JLOptions().debug_level == 2 @timeit "verify 3" (verify_ir(ir); verify_linetable(ir.linetable)) end diff --git a/base/compiler/ssair/verify.jl b/base/compiler/ssair/verify.jl index 5bf06d2994b8b..3f4bf83380538 100644 --- a/base/compiler/ssair/verify.jl +++ b/base/compiler/ssair/verify.jl @@ -28,7 +28,12 @@ function check_op(ir::IRCode, domtree::DomTree, @nospecialize(op), use_bb::Int, end end else - if !dominates(domtree, def_bb, use_bb) && !(bb_unreachable(domtree, def_bb) && bb_unreachable(domtree, use_bb)) + # TODO: hoist out is_sequential(ir) + if ( + is_sequential(ir) && # ignore this check before Tapir lowering + !dominates(domtree, def_bb, use_bb) && + !(bb_unreachable(domtree, def_bb) && bb_unreachable(domtree, use_bb)) + ) # At the moment, we allow GC preserve tokens outside the standard domination notion #@Base.show ir @verify_error "Basic Block $def_bb does not dominate block $use_bb (tried to use value $(op.id))" diff --git a/base/compiler/tapirpasses.jl b/base/compiler/tapirpasses.jl index 9e12cc803eae2..29c9f923762d9 100644 --- a/base/compiler/tapirpasses.jl +++ b/base/compiler/tapirpasses.jl @@ -977,6 +977,35 @@ function opaque_closure_method_from_ssair(ir::IRCode) ) end +is_sequential(src::CodeInfo) = all(x -> !(x isa DetachNode), src.code) + +function lower_tapir(interp::AbstractInterpreter, linfo::MethodInstance, ci::CodeInfo) + ccall(:jl_breakpoint, Cvoid, (Any,), ci) + is_sequential(ci) && return remove_tapir(ci) + + # Ref: _typeinf(interp::AbstractInterpreter, frame::InferenceState) + params = OptimizationParams(interp) + opt = OptimizationState(linfo, copy(ci), params, 
interp) + nargs = Int(opt.nargs) - 1 # Ref: optimize(interp, opt, params) + + # Ref: run_passes + preserve_coverage = coverage_enabled(opt.mod) + ir = convert_to_ircode(ci, copy_exprargs(ci.code), preserve_coverage, nargs, opt) + ir = slot2reg(ir, ci, nargs, opt) + @timeit "tapir" ir = lower_tapir!(ir) + if JLOptions().debug_level == 2 + @timeit "verify tapir" (verify_ir(ir); verify_linetable(ir.linetable)) + end + + finish(opt, params, ir, Any) # Ref: optimize(interp, opt, params) + finish(opt.src, interp) # Ref: _typeinf(interp, frame) + + return remove_tapir!(opt.src) +end + +lower_tapir(linfo::MethodInstance, ci::CodeInfo) = + lower_tapir(NativeInterpreter(), linfo, ci) + """ remove_tapir!(src::CodeInfo) remove_tapir!(_::Any) @@ -992,6 +1021,13 @@ function remove_tapir!(src::CodeInfo) src.code[i] = nothing end end - return + return src end remove_tapir!(::Any) = nothing + +function remove_tapir(src::CodeInfo) + any(src.code) do x + (x isa Union{DetachNode,ReattachNode,SyncNode}) || isexpr(x, :syncregion) + end && return remove_tapir!(copy(src)) # warn? + return src +end diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index ca700904d201e..0707162e1b67b 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -258,7 +258,6 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState) caller.src = nothing end caller.valid_worlds = opt.inlining.et.valid_worlds[] - remove_tapir!(opt.src) end end end diff --git a/src/codegen.cpp b/src/codegen.cpp index e9e48bf9da1b3..5628eb4bc671b 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -7364,6 +7364,28 @@ static std::pair<std::unique_ptr<Module>, jl_llvm_functions_t> void jl_add_code_in_flight(StringRef name, jl_code_instance_t *codeinst, const DataLayout &DL); +// Function `code_info_has_tapir(jl_code_info_t *src)` checks if `src` contains +// Tapir nodes such as detach. This function is written in C to avoid recursing +// into `lower_tapir` while compiling itself. 
+static int code_info_has_tapir(jl_code_info_t *src) { + for (size_t i = 0; i < jl_array_len(src->code); ++i) { + jl_value_t *stmt = jl_array_ptr_ref(src->code, i); + if (jl_is_detachnode(stmt)) { + return 1; + } else if (jl_is_reattachnode(stmt)) { + return 1; + } else if (jl_is_syncnode(stmt)) { + return 1; + } else if (jl_is_expr(stmt)) { + jl_expr_t *expr = (jl_expr_t*)stmt; + if (expr->head == syncregion_sym) { + return 1; + } + } + } + return 0; +} + JL_GCC_IGNORE_START("-Wclobbered") jl_compile_result_t jl_emit_code( jl_method_instance_t *li, @@ -7372,6 +7394,36 @@ jl_compile_result_t jl_emit_code( jl_codegen_params_t &params) { JL_TIMING(CODEGEN); + // ASK: Is there a better place/way to call `lower_tapir`? + if (jl_lower_tapir_func && jl_typeinf_world && code_info_has_tapir(src)) { + // Lower task parallel IR (detach etc.) to the calls to the parallel task + // runtime. This is done after optimized Julia IR is cached, so that + // parallel IR can be optimized across Julia functions (when inlined). + // Since `jl_emit_codeinst` can cache `CodeInfo`, this transformation + // cannot happen before it. 
+ jl_code_info_t *src0 = src; + jl_ptls_t ptls = jl_get_ptls_states(); + jl_value_t **fargs; + size_t last_age = ptls->world_age; + ptls->world_age = jl_typeinf_world; + JL_GC_PUSHARGS(fargs, 3); + fargs[0] = jl_lower_tapir_func; + fargs[1] = (jl_value_t *)li; + fargs[2] = (jl_value_t *)src; + JL_TRY { + src = (jl_code_info_t *)jl_apply(fargs, 3); + } + JL_CATCH { + jl_printf((JL_STREAM *)STDERR_FILENO, + "Internal error: encountered unexpected error in runtime:\n"); + jl_static_show((JL_STREAM *)STDERR_FILENO, jl_current_exception()); + jl_printf((JL_STREAM *)STDERR_FILENO, "\n"); + jlbacktrace(); // written to STDERR_FILENO + src = src0; + } + JL_GC_POP(); + ptls->world_age = last_age; + } // caller must hold codegen_lock jl_llvm_functions_t decls = {}; std::unique_ptr<Module> m; diff --git a/src/gf.c b/src/gf.c index 0c2cd7472e0d0..c0f4e743639c9 100644 --- a/src/gf.c +++ b/src/gf.c @@ -518,6 +518,13 @@ JL_DLLEXPORT void jl_set_typeinf_func(jl_value_t *f) } } +jl_function_t *jl_lower_tapir_func = NULL; + +JL_DLLEXPORT void jl_set_lower_tapir_func(jl_value_t *f) +{ + jl_lower_tapir_func = (jl_function_t*)f; +} + static int very_general_type(jl_value_t *t) { return (t == (jl_value_t*)jl_any_type || jl_types_equal(t, (jl_value_t*)jl_type_type)); diff --git a/src/jl_exported_funcs.inc b/src/jl_exported_funcs.inc index 3e16120453c76..9d6ce819a286d 100644 --- a/src/jl_exported_funcs.inc +++ b/src/jl_exported_funcs.inc @@ -557,6 +557,7 @@ XX(jl_set_errno) \ XX(jl_set_global) \ XX(jl_set_istopmod) \ + XX(jl_set_lower_tapir_func) \ XX(jl_set_module_compile) \ XX(jl_set_module_infer) \ XX(jl_set_module_nospecialize) \ diff --git a/src/julia_internal.h b/src/julia_internal.h index 1c3f2bcec4d61..9a6fba13df74c 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -166,6 +166,7 @@ extern jl_array_t *_jl_debug_method_invalidation JL_GLOBALLY_ROOTED; extern size_t jl_page_size; extern jl_function_t *jl_typeinf_func; extern size_t jl_typeinf_world; +extern jl_function_t 
*jl_lower_tapir_func; extern jl_typemap_entry_t *call_cache[N_CALL_CACHE] JL_GLOBALLY_ROOTED; extern jl_array_t *jl_all_methods JL_GLOBALLY_ROOTED; diff --git a/src/staticdata.c b/src/staticdata.c index 531ea4bf41005..849b85aa65883 100644 --- a/src/staticdata.c +++ b/src/staticdata.c @@ -30,7 +30,7 @@ extern "C" { // TODO: put WeakRefs on the weak_refs list during deserialization // TODO: handle finalizers -#define NUM_TAGS 149 +#define NUM_TAGS 150 // An array of references that need to be restored from the sysimg // This is a manually constructed dual of the gvars array, which would be produced by codegen for Julia code, for C. @@ -170,6 +170,7 @@ jl_value_t **const*const get_tags(void) { INSERT_TAG(jl_main_module); INSERT_TAG(jl_top_module); INSERT_TAG(jl_typeinf_func); + INSERT_TAG(jl_lower_tapir_func); INSERT_TAG(jl_type_type_mt); INSERT_TAG(jl_nonfunction_mt);