Skip to content

EA: can't find cache for @nospecialize-d call #51702

Open

Description

MRE:

julia> const EscapeAnalysis = Core.Compiler.EscapeAnalysis;

julia> let JULIA_DIR = normpath(Sys.BINDIR, "..", "share", "julia")
           # load `EAUtils` module to define the utilities
           include(normpath(JULIA_DIR, "test", "compiler", "EscapeAnalysis", "EAUtils.jl"))
           using .EAUtils
       end

julia> @noinline newsomeany(@nospecialize x) = Some{Any}(x);

julia> @noinline noinline_nospecialize_identity(@nospecialize x) = x;

julia> @inline function makesomenew(x0)
           x1 = newsomeany(x0)
           x2 = noinline_nospecialize_identity(x1)
           return x2
       end;

julia> code_escapes(makesomenew, (Any,))
makesomenew(X x0::Any) in Main at REPL[5]:1
X  1%1 = invoke Main.newsomeany(_2::Any)::Some{Any}
X  │   %2 = invoke Main.noinline_nospecialize_identity(%1::Any)::Some{Any}
◌  └──      return %2

In the above case, EA should propagate the escape information of noinline_nospecialize_identity so that %1 is not marked as escaping more than necessary. Currently, EA is unable to locate a cache for noinline_nospecialize_identity(::Any), because escape information is cached only for the method instance with the inferred signature, noinline_nospecialize_identity(::Some{Any}).

I tried to fix this by having the inlining pass propagate a new call info object that retains the inferred MethodInstance, so that post-optimization analyses can observe it:

diff --git a/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl
index 6a6994d497..c81959be68 100644
--- a/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl
+++ b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl
@@ -1093,12 +1093,20 @@ end
 
 # escape statically-resolved call, i.e. `Expr(:invoke, ::MethodInstance, ...)`
 function escape_invoke!(astate::AnalysisState, pc::Int, args::Vector{Any})
-    mi = first(args)::MethodInstance
     first_idx, last_idx = 2, length(args)
     # TODO inspect `astate.ir.stmts[pc][:info]` and use const-prop'ed `InferenceResult` if available
+    info = astate.ir[SSAValue(pc)][:info]
+    if info isa Core.Compiler.InvokeStmtInfo
+        mi = info.inferred
+    else
+        # TODO if this `:invoke` is from serialized IR, it currently doesn't have `InvokeStmtInfo`
+        mi = first(args)::MethodInstance
+    end
     cache = astate.get_escape_cache(mi)
     if cache === nothing
         return add_conservative_changes!(astate, pc, args, 2)
+    elseif cache === false
+        return nothing # no escapes!
     else
         cache = cache::ArgEscapeCache
     end
diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl
index 0a5b5c6580..a396b63951 100644
--- a/base/compiler/ssair/inlining.jl
+++ b/base/compiler/ssair/inlining.jl
@@ -33,11 +33,26 @@ struct SomeCase
 end
 
 struct InvokeCase
-    invoke::MethodInstance
+    invoke::MethodInstance    # method instance to be invoked
+    inferred::MethodInstance  # inferred method instance
     effects::Effects
     info::CallInfo
 end
 
+# TODO Embed this information within `CodeInfo` and allow serialized IR to recover it.
+"""
+    InvokeStmtInfo <: CallInfo
+
+This call info is introduced during the inlining pass to enable post-optimization analyses
+to keep track of inferred `MethodInstance`. After the inlining pass, the IR might not always
+retain inferred `MethodInstance` since it may widen it to `@nospecialize`-d invoke signature.
+"""
+struct InvokeStmtInfo <: CallInfo
+    info::CallInfo
+    inferred::MethodInstance
+    InvokeStmtInfo(@nospecialize(info::CallInfo), inferred::MethodInstance) = new(info, inferred)
+end
+
 struct InliningCase
     sig  # Type
     item # Union{InliningTodo, InvokeCase, ConstantCase}
@@ -614,7 +629,8 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::
         elseif isa(case, InvokeCase)
             invoke_stmt = Expr(:invoke, case.invoke, argexprs′...)
             flag = flags_for_effects(case.effects)
-            val = insert_node_here!(compact, NewInstruction(invoke_stmt, typ, case.info, line, flag))
+            invoke_info = InvokeStmtInfo(case.info, case.inferred)
+            val = insert_node_here!(compact, NewInstruction(invoke_stmt, typ, invoke_info, line, flag))
         else
             case = case::ConstantCase
             val = case.val
@@ -833,7 +849,7 @@ function compileable_specialization(mi::MethodInstance, effects::Effects,
     end
     add_inlining_backedge!(et, mi) # to the dispatch lookup
     push!(et.edges, method.sig, mi_invoke) # add_inlining_backedge to the invoke call
-    return InvokeCase(mi_invoke, effects, info)
+    return InvokeCase(mi_invoke, mi, effects, info)
 end
 
 function compileable_specialization(match::MethodMatch, effects::Effects,
@@ -1006,11 +1022,13 @@ function handle_single_case!(todo::Vector{Pair{Int,Any}},
     if isa(case, ConstantCase)
         ir[SSAValue(idx)][:stmt] = case.val
     elseif isa(case, InvokeCase)
-        is_foldable_nothrow(case.effects) && inline_const_if_inlineable!(ir[SSAValue(idx)]) && return nothing
+        inst = ir[SSAValue(idx)]
+        is_foldable_nothrow(case.effects) && inline_const_if_inlineable!(inst) && return nothing
         isinvoke && rewrite_invoke_exprargs!(stmt)
         stmt.head = :invoke
         pushfirst!(stmt.args, case.invoke)
-        ir[SSAValue(idx)][:flag] |= flags_for_effects(case.effects)
+        inst[:flag] |= flags_for_effects(case.effects)
+        inst[:info] = InvokeStmtInfo(inst[:info], case.inferred)
     elseif case === nothing
         # Do, well, nothing
     else
diff --git a/test/compiler/EscapeAnalysis/EscapeAnalysis.jl b/test/compiler/EscapeAnalysis/EscapeAnalysis.jl
index b598388a3d..473039ca48 100644
--- a/test/compiler/EscapeAnalysis/EscapeAnalysis.jl
+++ b/test/compiler/EscapeAnalysis/EscapeAnalysis.jl
@@ -2368,6 +2368,27 @@ function with_self_aliased(from_bb::Int, succs::Vector{Int})
 end
 @test code_escapes(with_self_aliased) isa EAUtils.EscapeResult
 
+# handle `@nospecialize`d `MethodInstance`s nicely
+@noinline newsomeany(@nospecialize x) = Some{Any}(x);
+@noinline noinline_nospecialize_identity(@nospecialize x) = x;
+@inline function makesomenew(x0)
+    x1 = newsomeany(x0)
+    x2 = noinline_nospecialize_identity(x1)
+    return x2
+end
+let result = code_escapes(makesomenew, (Any,))
+    i = only(findall(isinvoke(:newsomeany), result.ir.stmts.stmt))
+    @test has_return_escape(result.state[SSAValue(i)])
+    @test !has_thrown_escape(result.state[SSAValue(i)])
+end
+let result = code_escapes((Any,)) do x0
+        makesomenew(x0)
+    end
+    i = only(findall(isinvoke(:newsomeany), result.ir.stmts.stmt))
+    @test has_return_escape(result.state[SSAValue(i)])
+    @test_broken !has_thrown_escape(result.state[SSAValue(i)])
+end
+
 # accounts for ThrownEscape via potential MethodError
 
 # no method error

However, it turns out that this does not work when EA analyzes serialized IR: the post-inlining call info (InvokeStmtInfo) does not survive a serialization-deserialization cycle, because CodeInfo does not preserve it. This approach therefore runs into the same challenge as #47994, where I tried to embed extended lattice information into CodeInfo. We probably need to come up with a way to efficiently cache and serialize such complex information, so that IRCode inflated from a cached CodeInfo can recover it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions