Lower Tapir after inlining
Since optimized code is cached and used for inlining, Tapir must be lowered
after the optimized code is cached in order to enable Tapir for IPO; that is,
we want to cache optimized code that still contains `detach` etc.
This patch moves Tapir lowering out of the standard `run_passes`
optimization phase so that already-lowered code is never cached. Instead,
lowering happens inside `jl_emit_code`, just before LLVM IR is emitted.
tkf committed Mar 3, 2021
1 parent 832840e commit 0cff766
Showing 11 changed files with 109 additions and 5 deletions.
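
For context, here is a minimal sketch of the kind of code this ordering matters for. The `Tapir.@sync`/`Tapir.@spawn` macros are assumed to be the frontend provided by this branch and are used purely for illustration:

```julia
# Hypothetical parallel helper: with this patch, its *cached* optimized IR
# still contains detach/reattach/sync nodes rather than runtime calls.
function psum(xs)
    mid = length(xs) ÷ 2
    left = Ref(zero(eltype(xs)))
    right = Ref(zero(eltype(xs)))
    Tapir.@sync begin
        Tapir.@spawn left[] = sum(@view xs[1:mid])
        right[] = sum(@view xs[mid+1:end])
    end
    return left[] + right[]
end

# When `psum` is inlined here, the caller sees the unlowered parallel IR, so
# Tapir-level optimization can work across the call boundary (IPO); lowering
# to task-runtime calls only happens in `jl_emit_code`, right before LLVM IR
# is emitted.
caller(xs) = psum(xs) + one(eltype(xs))
```
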
1 change: 1 addition & 0 deletions base/compiler/compiler.jl
@@ -141,6 +141,7 @@ include("compiler/tapirpasses.jl")

include("compiler/bootstrap.jl")
ccall(:jl_set_typeinf_func, Cvoid, (Any,), typeinf_ext_toplevel)
ccall(:jl_set_lower_tapir_func, Cvoid, (Any,), lower_tapir)

include("compiler/parsing.jl")
Core.eval(Core, :(_parse = Compiler.fl_parse))
2 changes: 2 additions & 0 deletions base/compiler/optimize.jl
@@ -62,6 +62,8 @@ mutable struct OptimizationState
nssavalues = src.ssavaluetypes
if nssavalues isa Int
src.ssavaluetypes = Any[ Any for i = 1:nssavalues ]
else
nssavalues = length(src.ssavaluetypes)
end
nslots = length(src.slotflags)
slottypes = src.slottypes
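
The new `else` branch matters because `lower_tapir` (added below in `tapirpasses.jl`) constructs an `OptimizationState` from a `CodeInfo` that has already been through inference and optimization, where `ssavaluetypes` is a vector of types rather than a count. A quick, hedged way to see both shapes (field layout as of this commit; the reflection calls are only illustrative):

```julia
# Freshly lowered, un-inferred code: `ssavaluetypes` is just an `Int` count.
ci_lowered = Meta.lower(Main, :(x + y * z)).args[1]
ci_lowered.ssavaluetypes isa Int                              # true

# Inferred code, as cached and later handed back to `lower_tapir`:
# `ssavaluetypes` is a vector with one entry per statement.
ci_inferred, _ = only(code_typed(sin, (Float64,)))
ci_inferred.ssavaluetypes isa Vector                          # true
length(ci_inferred.ssavaluetypes) == length(ci_inferred.code) # true
```
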
1 change: 0 additions & 1 deletion base/compiler/ssair/driver.jl
@@ -140,7 +140,6 @@ function run_passes(ci::CodeInfo, nargs::Int, sv::OptimizationState)
@timeit "type lift" ir = type_lift_pass!(ir)
@timeit "compact 3" ir = compact!(ir)
#@Base.show ir
@timeit "tapir" ir = lower_tapir!(ir)
if JLOptions().debug_level == 2
@timeit "verify 3" (verify_ir(ir); verify_linetable(ir.linetable))
end
7 changes: 6 additions & 1 deletion base/compiler/ssair/verify.jl
@@ -28,7 +28,12 @@ function check_op(ir::IRCode, domtree::DomTree, @nospecialize(op), use_bb::Int,
end
end
else
- if !dominates(domtree, def_bb, use_bb) && !(bb_unreachable(domtree, def_bb) && bb_unreachable(domtree, use_bb))
+ # TODO: hoist out is_sequential(ir)
+ if (
+     is_sequential(ir) && # ignore this check before Tapir lowering
+     !dominates(domtree, def_bb, use_bb) &&
+     !(bb_unreachable(domtree, def_bb) && bb_unreachable(domtree, use_bb))
+ )
# At the moment, we allow GC preserve tokens outside the standard domination notion
#@Base.show ir
@verify_error "Basic Block $def_bb does not dominate block $use_bb (tried to use value $(op.id))"
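
The verifier now skips the SSA dominance check while the IR still contains Tapir nodes. The typical offender is the task-output pattern: a value written inside a detached block is read after the corresponding sync, so its defining block does not dominate the use until `lower_tapir!` rewrites the detach/reattach edges into runtime calls. A rough analogue using ordinary tasks (only to illustrate the dataflow; Tapir encodes the same structure directly in the IR):

```julia
function task_output_pattern(x)
    out = Ref{Int}()
    t = Threads.@spawn (out[] = x + 1)  # "detached" computation defines the value
    y = 2x                              # continuation runs concurrently
    wait(t)                             # "sync": the task is guaranteed to be done
    return out[] + y                    # use is only reached after the sync
end
```
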
38 changes: 37 additions & 1 deletion base/compiler/tapirpasses.jl
@@ -977,6 +977,35 @@ function opaque_closure_method_from_ssair(ir::IRCode)
)
end

is_sequential(src::CodeInfo) = all(x -> !(x isa DetachNode), src.code)

function lower_tapir(interp::AbstractInterpreter, linfo::MethodInstance, ci::CodeInfo)
ccall(:jl_breakpoint, Cvoid, (Any,), ci)
is_sequential(ci) && return remove_tapir(ci)

# Ref: _typeinf(interp::AbstractInterpreter, frame::InferenceState)
params = OptimizationParams(interp)
opt = OptimizationState(linfo, copy(ci), params, interp)
nargs = Int(opt.nargs) - 1 # Ref: optimize(interp, opt, params)

# Ref: run_passes
preserve_coverage = coverage_enabled(opt.mod)
ir = convert_to_ircode(ci, copy_exprargs(ci.code), preserve_coverage, nargs, opt)
ir = slot2reg(ir, ci, nargs, opt)
@timeit "tapir" ir = lower_tapir!(ir)
if JLOptions().debug_level == 2
@timeit "verify tapir" (verify_ir(ir); verify_linetable(ir.linetable))
end

finish(opt, params, ir, Any) # Ref: optimize(interp, opt, params)
finish(opt.src, interp) # Ref: _typeinf(interp, frame)

return remove_tapir!(opt.src)
end

lower_tapir(linfo::MethodInstance, ci::CodeInfo) =
lower_tapir(NativeInterpreter(), linfo, ci)

"""
remove_tapir!(src::CodeInfo)
remove_tapir!(_::Any)
@@ -992,6 +1021,13 @@ function remove_tapir!(src::CodeInfo)
src.code[i] = nothing
end
end
- return
+ return src
end
remove_tapir!(::Any) = nothing

function remove_tapir(src::CodeInfo)
any(src.code) do x
(x isa Union{DetachNode,ReattachNode,SyncNode}) || isexpr(x, :syncregion)
end && return remove_tapir!(copy(src)) # warn?
return src
end
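
Taken together, `lower_tapir(interp, linfo, ci)` re-runs a small slice of the optimization pipeline (`convert_to_ircode` → `slot2reg` → `lower_tapir!` → `finish`) on a copy of the cached `CodeInfo`, while `remove_tapir`/`remove_tapir!` strip leftover Tapir markers so that sequential code passes through unchanged. A hedged sketch of the fast path, checking `is_sequential` by hand on ordinary sequential code (the `Core.Compiler.DetachNode` binding is assumed to match this branch):

```julia
const CC = Core.Compiler

# Any already-optimized CodeInfo of a plain sequential function...
ci, _ = only(code_typed(+, (Int, Int)))

# ...contains no DetachNode, so `is_sequential(ci)` is true and `lower_tapir`
# returns early via `remove_tapir(ci)` without ever building an IRCode.
@assert all(x -> !(x isa CC.DetachNode), ci.code)
```
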
1 change: 0 additions & 1 deletion base/compiler/typeinfer.jl
@@ -258,7 +258,6 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
caller.src = nothing
end
caller.valid_worlds = opt.inlining.et.valid_worlds[]
- remove_tapir!(opt.src)
end
end
end
52 changes: 52 additions & 0 deletions src/codegen.cpp
@@ -7364,6 +7364,28 @@ static std::pair<std::unique_ptr<Module>, jl_llvm_functions_t>

void jl_add_code_in_flight(StringRef name, jl_code_instance_t *codeinst, const DataLayout &DL);

// Function `code_info_has_tapir(jl_code_info_t *src)` checks if `src` contains
// Tapir nodes such as detach. This function is written in C to avoid recursing
// into `lower_tapir` while compiling itself.
static int code_info_has_tapir(jl_code_info_t *src) {
for (size_t i = 0; i < jl_array_len(src->code); ++i) {
jl_value_t *stmt = jl_array_ptr_ref(src->code, i);
if (jl_is_detachnode(stmt)) {
return 1;
} else if (jl_is_reattachnode(stmt)) {
return 1;
} else if (jl_is_syncnode(stmt)) {
return 1;
} else if (jl_is_expr(stmt)) {
jl_expr_t *expr = (jl_expr_t*)stmt;
if (expr->head == syncregion_sym) {
return 1;
}
}
}
return 0;
}

JL_GCC_IGNORE_START("-Wclobbered")
jl_compile_result_t jl_emit_code(
jl_method_instance_t *li,
@@ -7372,6 +7394,36 @@ jl_compile_result_t jl_emit_code(
jl_codegen_params_t &params)
{
JL_TIMING(CODEGEN);
// ASK: Is there a better place/way to call `lower_tapir`?
if (jl_lower_tapir_func && jl_typeinf_world && code_info_has_tapir(src)) {
// Lower task parallel IR (detach etc.) to calls into the parallel task
// runtime. This is done after optimized Julia IR is cached, so that
// parallel IR can be optimized across Julia functions (when inlined).
// Since `jl_emit_codeinst` can cache `CodeInfo`, this transformation
// cannot happen before it.
jl_code_info_t *src0 = src;
jl_ptls_t ptls = jl_get_ptls_states();
jl_value_t **fargs;
size_t last_age = ptls->world_age;
ptls->world_age = jl_typeinf_world;
JL_GC_PUSHARGS(fargs, 3);
fargs[0] = jl_lower_tapir_func;
fargs[1] = (jl_value_t *)li;
fargs[2] = (jl_value_t *)src;
JL_TRY {
src = (jl_code_info_t *)jl_apply(fargs, 3);
}
JL_CATCH {
jl_printf((JL_STREAM *)STDERR_FILENO,
"Internal error: encountered unexpected error in runtime:\n");
jl_static_show((JL_STREAM *)STDERR_FILENO, jl_current_exception());
jl_printf((JL_STREAM *)STDERR_FILENO, "\n");
jlbacktrace(); // written to STDERR_FILENO
src = src0;
}
JL_GC_POP();
ptls->world_age = last_age;
}
// caller must hold codegen_lock
jl_llvm_functions_t decls = {};
std::unique_ptr<Module> m;
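
In Julia-flavored pseudocode, the new hook in `jl_emit_code` amounts to the following. The C version additionally switches `world_age` to `jl_typeinf_world` while calling back into the compiler; `code_has_tapir` stands in for the C helper `code_info_has_tapir` and is named here only for illustration:

```julia
function maybe_lower_tapir(mi, ci)
    code_has_tapir(ci) || return ci      # cheap scan, mirrors code_info_has_tapir
    try
        return Core.Compiler.lower_tapir(mi, ci)
    catch err
        # Mirrors the JL_CATCH branch: report the error and fall back to the
        # unlowered CodeInfo instead of aborting codegen.
        @error "Tapir lowering failed; emitting unlowered code" exception = (err, catch_backtrace())
        return ci
    end
end
```
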
7 changes: 7 additions & 0 deletions src/gf.c
@@ -518,6 +518,13 @@ JL_DLLEXPORT void jl_set_typeinf_func(jl_value_t *f)
}
}

jl_function_t *jl_lower_tapir_func = NULL;

JL_DLLEXPORT void jl_set_lower_tapir_func(jl_value_t *f)
{
jl_lower_tapir_func = (jl_function_t*)f;
}

static int very_general_type(jl_value_t *t)
{
return (t == (jl_value_t*)jl_any_type || jl_types_equal(t, (jl_value_t*)jl_type_type));
1 change: 1 addition & 0 deletions src/jl_exported_funcs.inc
@@ -557,6 +557,7 @@
XX(jl_set_errno) \
XX(jl_set_global) \
XX(jl_set_istopmod) \
XX(jl_set_lower_tapir_func) \
XX(jl_set_module_compile) \
XX(jl_set_module_infer) \
XX(jl_set_module_nospecialize) \
1 change: 1 addition & 0 deletions src/julia_internal.h
@@ -166,6 +166,7 @@ extern jl_array_t *_jl_debug_method_invalidation JL_GLOBALLY_ROOTED;
extern size_t jl_page_size;
extern jl_function_t *jl_typeinf_func;
extern size_t jl_typeinf_world;
extern jl_function_t *jl_lower_tapir_func;
extern jl_typemap_entry_t *call_cache[N_CALL_CACHE] JL_GLOBALLY_ROOTED;
extern jl_array_t *jl_all_methods JL_GLOBALLY_ROOTED;

3 changes: 2 additions & 1 deletion src/staticdata.c
@@ -30,7 +30,7 @@ extern "C" {
// TODO: put WeakRefs on the weak_refs list during deserialization
// TODO: handle finalizers

- #define NUM_TAGS 149
+ #define NUM_TAGS 150

// An array of references that need to be restored from the sysimg
// This is a manually constructed dual of the gvars array, which would be produced by codegen for Julia code, for C.
@@ -170,6 +170,7 @@ jl_value_t **const*const get_tags(void) {
INSERT_TAG(jl_main_module);
INSERT_TAG(jl_top_module);
INSERT_TAG(jl_typeinf_func);
INSERT_TAG(jl_lower_tapir_func);
INSERT_TAG(jl_type_type_mt);
INSERT_TAG(jl_nonfunction_mt);

