-
Notifications
You must be signed in to change notification settings - Fork 55
Allow for generic MethodTableView #494
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #494 +/- ##
==========================================
+ Coverage 73.54% 73.71% +0.17%
==========================================
Files 24 24
Lines 3485 3519 +34
==========================================
+ Hits 2563 2594 +31
- Misses 922 925 +3 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
1d233d7
to
e18b7c2
Compare
For now just rebased and included the code from Shenanigans.jl |
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/src/interface.jl b/src/interface.jl
index 7553dbb..cb0c05f 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -231,12 +231,14 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false
# provide a specific interpreter to use.
if VERSION >= v"1.11.0-DEV.1552"
get_interpreter(@nospecialize(job::CompilerJob)) =
- GPUInterpreter(job.world; method_table_view=maybe_cached(method_table_view(job)),
+ GPUInterpreter(
+ job.world; method_table_view = maybe_cached(method_table_view(job)),
token=ci_cache_token(job), inf_params=inference_params(job),
opt_params=optimization_params(job))
else
get_interpreter(@nospecialize(job::CompilerJob)) =
- GPUInterpreter(job.world; method_table_view=maybe_cached(method_table_view(job)),
+ GPUInterpreter(
+ job.world; method_table_view = maybe_cached(method_table_view(job)),
code_cache=ci_cache(job), inf_params=inference_params(job),
opt_params=optimization_params(job))
end
diff --git a/src/jlgen.jl b/src/jlgen.jl
index 1bff8a0..4982c64 100644
--- a/src/jlgen.jl
+++ b/src/jlgen.jl
@@ -300,7 +300,7 @@ Base.Experimental.@MethodTable(GLOBAL_METHOD_TABLE)
# Implements a priority lookup for method tables, where the first match in the stack get's returned.
# An alternative to this would be to use a "Union" where we would query the parent method table and
# do a most-specific match.
-struct StackedMethodTable{MTV<:CC.MethodTableView} <: CC.MethodTableView
+struct StackedMethodTable{MTV <: CC.MethodTableView} <: CC.MethodTableView
world::UInt
mt::Core.MethodTable
parent::MTV
@@ -313,7 +313,7 @@ CC.isoverlayed(::StackedMethodTable) = true
@static if VERSION >= v"1.11.0-DEV.363"
# https://github.com/JuliaLang/julia/pull/51078
# same API as before but without returning isoverlayed flag
- function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1)
+ function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int = -1)
result = CC._findall(sig, table.mt, table.world, limit)
result === nothing && return nothing # to many matches
nr = CC.length(result)
@@ -321,17 +321,19 @@ CC.isoverlayed(::StackedMethodTable) = true
# no need to fall back to the parent method view
return result
end
-
+
parent_result = CC.findall(sig, table.parent; limit)::Union{Nothing, CC.MethodLookupResult}
parent_result === nothing && return nothing #too many matches
-
+
# merge the parent match results with the internal method table
return CC.MethodLookupResult(
CC.vcat(result.matches, parent_result.matches),
CC.WorldRange(
CC.max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world),
- CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)),
- result.ambig | parent_result.ambig)
+ CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)
+ ),
+ result.ambig | parent_result.ambig
+ )
end
function CC.findsup(@nospecialize(sig::Type), table::StackedMethodTable)
@@ -342,11 +344,12 @@ CC.isoverlayed(::StackedMethodTable) = true
parent_match,
CC.WorldRange(
max(valid_worlds.min_world, parent_valid_worlds.min_world),
- min(valid_worlds.max_world, parent_valid_worlds.max_world))
- )
+ min(valid_worlds.max_world, parent_valid_worlds.max_world)
+ ),
+ )
end
else
- function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1)
+ function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int = -1)
result = CC._findall(sig, table.mt, table.world, limit)
result === nothing && return nothing # to many matches
nr = CC.length(result)
@@ -354,22 +357,25 @@ else
# no need to fall back to the parent method view
return CC.MethodMatchResult(result, true)
end
-
+
parent_result = CC.findall(sig, table.parent; limit)::Union{Nothing, CC.MethodMatchResult}
parent_result === nothing && return nothing #too many matches
-
+
overlayed = parent_result.overlayed | !CC.isempty(result)
parent_result = parent_result.matches::CC.MethodLookupResult
-
+
# merge the parent match results with the internal method table
return CC.MethodMatchResult(
- CC.MethodLookupResult(
- CC.vcat(result.matches, parent_result.matches),
- CC.WorldRange(
- CC.max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world),
- CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)),
- result.ambig | parent_result.ambig),
- overlayed)
+ CC.MethodLookupResult(
+ CC.vcat(result.matches, parent_result.matches),
+ CC.WorldRange(
+ CC.max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world),
+ CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)
+ ),
+ result.ambig | parent_result.ambig
+ ),
+ overlayed
+ )
end
function CC.findsup(@nospecialize(sig::Type), table::StackedMethodTable)
@@ -380,8 +386,10 @@ else
parent_match,
CC.WorldRange(
max(valid_worlds.min_world, parent_valid_worlds.min_world),
- min(valid_worlds.max_world, parent_valid_worlds.max_world)),
- overlayed)
+ min(valid_worlds.max_world, parent_valid_worlds.max_world)
+ ),
+ overlayed,
+ )
end
end
@@ -404,7 +412,7 @@ end
get_method_table_view(world::UInt, mt::CC.MethodTable) = CC.OverlayMethodTable(world, mt)
-struct GPUInterpreter{MTV<:CC.MethodTableView} <: CC.AbstractInterpreter
+struct GPUInterpreter{MTV <: CC.MethodTableView} <: CC.AbstractInterpreter
world::UInt
method_table_view::MTV
@@ -421,7 +429,7 @@ end
@static if HAS_INTEGRATED_CACHE
function GPUInterpreter(world::UInt=Base.get_world_counter();
- method_table_view::CC.MethodTableView,
+ method_table_view::CC.MethodTableView,
token::Any,
inf_params::CC.InferenceParams,
opt_params::CC.OptimizationParams)
@@ -429,19 +437,21 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
inf_cache = Vector{CC.InferenceResult}()
- return GPUInterpreter(world, method_table_view,
+ return GPUInterpreter(
+ world, method_table_view,
token, inf_cache,
inf_params, opt_params)
end
function GPUInterpreter(interp::GPUInterpreter;
world::UInt=interp.world,
- method_table_view::Core.MethodTable=interp.method_table_view,
+ method_table_view::Core.MethodTable = interp.method_table_view,
token::Any=interp.token,
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
inf_params::CC.InferenceParams=interp.inf_params,
opt_params::CC.OptimizationParams=interp.opt_params)
- return GPUInterpreter(world, method_table_view,
+ return GPUInterpreter(
+ world, method_table_view,
token, inf_cache,
inf_params, opt_params)
end
@@ -449,7 +459,7 @@ end
else
function GPUInterpreter(world::UInt=Base.get_world_counter();
- method_table_view::CC.MethodTableView,
+ method_table_view::CC.MethodTableView,
code_cache::CodeCache,
inf_params::CC.InferenceParams,
opt_params::CC.OptimizationParams)
@@ -457,19 +467,21 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
inf_cache = Vector{CC.InferenceResult}()
- return GPUInterpreter(world, method_table_view,
+ return GPUInterpreter(
+ world, method_table_view,
code_cache, inf_cache,
inf_params, opt_params)
end
function GPUInterpreter(interp::GPUInterpreter;
world::UInt=interp.world,
- method_table_view::CC.MethodTableView=interp.method_table_view,
+ method_table_view::CC.MethodTableView = interp.method_table_view,
code_cache::CodeCache=interp.code_cache,
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
inf_params::CC.InferenceParams=interp.inf_params,
opt_params::CC.OptimizationParams=interp.opt_params)
- return GPUInterpreter(world, method_table_view,
+ return GPUInterpreter(
+ world, method_table_view,
code_cache, inf_cache,
inf_params, opt_params)
end
diff --git a/test/utils.jl b/test/utils.jl
index f0de138..f70cd94 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -111,14 +111,14 @@ DoubleStackedMT() = StackedMethodTable(Base.get_world_counter(), OtherMT, LayerM
@test isoverlayed(DoubleStackedMT()) == true
end
- o_sin = findsup(Tuple{typeof(sin), Float64}, OverlayMT())
- s_sin = findsup(Tuple{typeof(sin), Float64}, StackedMT())
+ o_sin = findsup(Tuple{typeof(sin), Float64}, OverlayMT())
+ s_sin = findsup(Tuple{typeof(sin), Float64}, StackedMT())
ss_sin = findsup(Tuple{typeof(sin), Float64}, DoubleStackedMT())
@test s_sin == o_sin
@test ss_sin == o_sin
- o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT())
- s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT())
+ o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT())
+ s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT())
ss_sin = findall(Tuple{typeof(sin), Float64}, DoubleStackedMT())
if VERSION >= v"1.11.0-DEV.363"
@test o_sin.matches == s_sin.matches
@@ -150,8 +150,8 @@ next_world = Base.get_world_counter()
@test worlds.min_world > prev_world
@test worlds.max_world == typemax(typeof(next_world))
- o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT())
- s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT())
+ o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT())
+ s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT())
ss_sin = findall(Tuple{typeof(sin), Float64}, DoubleStackedMT())
if VERSION >= v"1.11.0-DEV.363"
@test o_sin.matches == s_sin.matches |
@aviatesk can you take a look if this makes sense? The goal here is to have a stack of three method tables, the first layer superseding the second, the second superseding the third. The third is Julia's internal method table. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From an implementation perspective, this seems fine.
One point to note is that in the current overlay method table mehcanism, simply looking at .fully_covers
cannot reproduce the specialization behavior that occurs in normal dispatch.
Consider the following example:
Base.Experimental.@MethodTable mt
func(a::Number) = :number
func(a::Integer) = :integer
Base.Experimental.@overlay mt func(a::Number) = :number_overlay
In this case, the correct specialization for Tuple{typeof(func), Int}
should naturally be dispatch to func(::Integer)
:
julia> Base.Compiler.findall(Tuple{typeof(func),Int}, Base.Compiler.InternalMethodTable(Base.get_world_counter()))
Compiler.MethodLookupResult(Any[Core.MethodMatch(Tuple{typeof(func), Int64}, svec(), func(a::Integer) @ Main REPL[14]:1, true)], Compiler.WorldRange(0x00000000000097cd, 0xffffffffffffffff), false)
However, when using overlay method tables, this rule is not preserved, and something like "local specialization within the overlay method table" takes in place:
julia> result = Base.Compiler.findall(Tuple{typeof(func),Int}, Base.Compiler.OverlayMethodTable(Base.get_world_counter(), mt))
Compiler.MethodLookupResult(Any[Core.MethodMatch(Tuple{typeof(func), Int64}, svec(), func(a::Number) @ Main REPL[15]:1, true)], Compiler.WorldRange(0x00000000000097ce, 0xffffffffffffffff), false)
julia> result.matches[1].fully_covers
true
This issue likely applies to this PR as well. It might be worth being careful about whether there are any problems regarding this point. This is more of an implementation issue with Base.Compiler.findall(type, ::OverlayMethodTable)
rather than a problem on the side that uses OverlayMethodTable
though.
For our uses that is fine, but yes for some of the potential use-cases that might be an issue. |
Currently, SPIRVIntrinsics, OpenCL, oneAPI, and POCL all share one method-table. We would like to disambiguate package specific overlays from the actual common intrinsic definitions.