Description
Opening an issue to track JuliaGPU/GPUCompiler.jl#384, and discuss how we can get something like #51080 merged.
Context: The GPU stack is a heavy user of overlay methods, to make functionality GPU-compatible or otherwise provide a GPU-specific implementation. One area where we need such overlays is the outlined `throw_XXX` methods that throw objects requiring allocations. For example, `InexactError` contains untyped fields and as such is currently GPU-incompatible, so we overlay `Core.throw_inexacterror` with a simplified version that only throws a message.
Most of the time, overlay methods are unsafe to execute on the host, e.g., because they use GPU-specific functionality. AFAIU, that's why concrete evaluation of them is prohibited. However, because we overlay very common core functionality, this prevents lots of functionality from being optimized and frequently results in GPU-incompatible code being generated.
An example that resulted from replacing @pure
with effects modeling (#44776):
using CUDA
f(x) = Float32(x, RoundDown)
InteractiveUtils.code_llvm(f, Tuple{Irrational{:π}})
CUDA.code_llvm(f, Tuple{Irrational{:π}})
define float @julia_f_6003() #0 {
top:
ret float 0x400921FB40000000
}
define float @julia_f_9717() local_unnamed_addr {
top:
; ┌ @ irrationals.jl:69 within `Type`
; │┌ @ mpfr.jl:1099 within `setprecision`
%0 = call fastcc float @julia__setprecision_25_9723(i64 signext 256)
; └└
ret float %0
}
Or, for a MWE without the GPU stack:
const CC = Core.Compiler
using Core: MethodInstance, CodeInstance, CodeInfo, MethodTable
## code instance cache
# External code-instance cache: maps each `MethodInstance` to the list of
# `CodeInstance`s that have been inferred for it.
struct CodeCache
    dict::IdDict{MethodInstance,Vector{CodeInstance}}

    # Start out with an empty mapping.
    function CodeCache()
        return new(IdDict{MethodInstance,Vector{CodeInstance}}())
    end
end
"""
    CC.setindex!(cache::CodeCache, ci::CodeInstance, mi::MethodInstance)

Register `ci` as a cached code instance for `mi`, appending it to the
(possibly newly created) list of instances for that method instance.
"""
function CC.setindex!(cache::CodeCache, ci::CodeInstance, mi::MethodInstance)
    # Use the function form of `get!` so the empty default vector is only
    # allocated when `mi` is not present yet; `get!(dict, mi, CodeInstance[])`
    # builds the default eagerly on every call.
    cis = get!(() -> CodeInstance[], cache.dict, mi)
    push!(cis, ci)
end
## world view of the cache
# A method instance is "present" in the world view iff a valid code instance
# can be retrieved for it.
CC.haskey(wvc::CC.WorldView{CodeCache}, mi::MethodInstance) =
    CC.get(wvc, mi, nothing) !== nothing
"""
    CC.get(wvc::CC.WorldView{CodeCache}, mi::MethodInstance, default)

Return a cached `CodeInstance` for `mi` whose world-age range covers the
world range of `wvc`, or `default` when none is cached.
"""
function CC.get(wvc::CC.WorldView{CodeCache}, mi::MethodInstance, default)
    # Non-mutating lookup: the previous `get!` inserted an empty vector into
    # the cache on every miss, growing the dict as a side effect of reads.
    for ci in get(wvc.cache.dict, mi, ())
        # Only instances valid across the entire queried world range count.
        if ci.min_world <= wvc.worlds.min_world && wvc.worlds.max_world <= ci.max_world
            # NOTE: the old code also decompressed `ci.inferred` via
            # `jl_uncompress_ir` here and discarded the result; that dead
            # (and costly) work has been removed.
            return ci
        end
    end
    return default
end
# Indexing form of the lookup above: error with `KeyError` when no valid
# code instance exists for `mi` in this world view.
function CC.getindex(wvc::CC.WorldView{CodeCache}, mi::MethodInstance)
    entry = CC.get(wvc, mi, nothing)
    entry === nothing && throw(KeyError(mi))
    return entry::CodeInstance
end
"""
    CC.setindex!(wvc::CC.WorldView{CodeCache}, ci::CodeInstance, mi::MethodInstance)

Store `ci` for `mi` in the underlying `CodeCache` of this world view.
"""
function CC.setindex!(wvc::CC.WorldView{CodeCache}, ci::CodeInstance, mi::MethodInstance)
    # The previous implementation eagerly decompressed `ci.inferred` with
    # `jl_uncompress_ir` and discarded the result; that dead work is removed.
    CC.setindex!(wvc.cache, ci, mi)
end
## interpreter
# Newer Julia versions provide `CC.CachedMethodTable`, which memoizes method
# lookups; wrap the overlay table in it when available, otherwise use the
# plain `CC.OverlayMethodTable` directly.
if isdefined(CC, :CachedMethodTable)
const ExternalMethodTableView = CC.CachedMethodTable{CC.OverlayMethodTable}
get_method_table_view(world::UInt, mt::MethodTable) =
CC.CachedMethodTable(CC.OverlayMethodTable(world, mt))
else
# Fallback for Julia versions without `CachedMethodTable`.
const ExternalMethodTableView = CC.OverlayMethodTable
get_method_table_view(world::UInt, mt::MethodTable) = CC.OverlayMethodTable(world, mt)
end
"""
External `AbstractInterpreter` that performs inference against an overlay
method table and stores results in an external code cache instead of the
runtime's native caches.
"""
struct ExternalInterpreter{C} <: CC.AbstractInterpreter
    world::UInt
    method_table::ExternalMethodTableView
    # Parametrized instead of the former untyped (implicitly `Any`) field,
    # so accesses to the cache stay type-stable.
    code_cache::C
    inf_cache::Vector{CC.InferenceResult}
end
"""
    ExternalInterpreter(world=Base.get_world_counter(); method_table, code_cache)

Construct an `ExternalInterpreter` operating at world age `world`, using the
given overlay `method_table` and external `code_cache`.
"""
function ExternalInterpreter(world::UInt=Base.get_world_counter(); method_table, code_cache)
    # Validate explicitly: `@assert` may be disabled at higher optimization
    # levels and is not meant for input validation.
    world <= Base.get_world_counter() ||
        throw(ArgumentError("world $world is newer than the current world counter"))
    method_table = get_method_table_view(world, method_table)
    inf_cache = Vector{CC.InferenceResult}()
    return ExternalInterpreter(world, method_table, code_cache, inf_cache)
end
# Inference and optimization parameters: use the compiler's defaults.
CC.InferenceParams(::ExternalInterpreter) = CC.InferenceParams()
CC.OptimizationParams(::ExternalInterpreter) = CC.OptimizationParams()
# World age and caches come straight from the interpreter's fields.
CC.get_world_counter(ei::ExternalInterpreter) = ei.world
CC.get_inference_cache(ei::ExternalInterpreter) = ei.inf_cache
CC.code_cache(ei::ExternalInterpreter) = CC.WorldView(ei.code_cache, ei.world)
# No need to do any locking since we're not putting our results into the runtime cache
CC.lock_mi_inference(::ExternalInterpreter, ::MethodInstance) = nothing
CC.unlock_mi_inference(::ExternalInterpreter, ::MethodInstance) = nothing
# Surface inference remarks as debug-level log messages only.
CC.add_remark!(::ExternalInterpreter, sv::CC.InferenceState, msg) =
    @debug "Inference remark during External compilation of $(sv.linfo): $msg"
# Enable the standard optimization pipeline and allow compressing and
# discarding inferred IR; keep statement info terse.
CC.may_optimize(::ExternalInterpreter) = true
CC.may_compress(::ExternalInterpreter) = true
CC.may_discard_trees(::ExternalInterpreter) = true
CC.verbose_stmt_info(::ExternalInterpreter) = false
# Route method lookups through the overlay-aware table.
CC.method_table(ei::ExternalInterpreter) = ei.method_table
# main
# Dedicated method table to hold our overlay methods.
Base.Experimental.@MethodTable(GLOBAL_METHOD_TABLE)
# Overlay `Core.throw_inexacterror` with a version that just returns instead
# of allocating and throwing an `InexactError` — mimicking the GPU stack's
# allocation-free throw_ overlays described above.
Base.Experimental.@overlay(GLOBAL_METHOD_TABLE,
@noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} = return
)
f(x) = Float32(x, RoundDown)
# Show `f`'s optimized IR twice: once with native inference, once through the
# external interpreter that applies the overlay method table.
function main()
    println("Native:")
    display(Base.code_ircode(f, Tuple{Irrational{:π}}))
    println()
    println("External:")
    external = ExternalInterpreter(; method_table=GLOBAL_METHOD_TABLE,
                                     code_cache=CodeCache())
    display(Base.code_ircode(f, Tuple{Irrational{:π}}; interp=external))
    return
end
# Run automatically when executed as a script, not in an interactive session.
isinteractive() || main()
Native:
1-element Vector{Any}:
118 1 ─ return 3.1415925f0 │
=> Float32
External:
1-element Vector{Any}:
118 1 ─ %1 = Base.setprecision::typeof(setprecision) │╻ Type
│ %2 = Base.BigFloat::Type{BigFloat} ││
│ %3 = invoke Base.MPFR.:(var"#setprecision#25")($(QuoteNode(Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}()))::Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, %1::typeof(setprecision), Base.var"#980#981"{Float32, Irrational{:π}, RoundingMode{:Down}}(π, RoundingMode{:Down}())::Base.var"#980#981"{Float32, Irrational{:π}, RoundingMode{:Down}}, %2::Type{BigFloat}, 256::Int64)::Float32
└── return %3 │
=> Float32
Another example is #48097, which we "fixed" by avoiding the calls to Core.throw_inexacterror
in #48116. That kind of solution obviously doesn't scale.
To properly solve this, we probably have to define precise semantics of method overlays, and how they affect optimization.
For example, we could offer the following possibilities:
- :taint (the default, and current behavior): concrete evaluation of a call is disabled if it calls this overlay method
- :equivalent: the overlay method is functionally equivalent to the original method, so the compiler can use information from (i.e., concretely evaluate) the original method to optimize the call
- :executable: the overlay method is safe to execute on the host, so concrete evaluation can use it directly
:equivalent semantics are required for most GPU overlays (e.g., when replacing openlibm functions with NVIDIA's GPU-only math library), but they are slightly dangerous, as I imagine it could be tricky to guarantee that the overlay is actually functionally identical. That's why, when possible, I would consider the :executable semantics the better option.
Note that I'm writing the above from the perspective of the GPUCompiler.jl use case, without much experience with the optimizer/irinterp/effects, so I'm probably missing some important details.
#51080 by @aviatesk implements something similar to this, basically making it possible to mark overlay methods as non-overlay, but as @Keno mentions there we probably need to be slightly more precise.
Tentatively putting this on the milestone, as we're running into this more often now that the optimizer is relying on effects more.