Skip to content

Allow for generic MethodTableView #494

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 24, 2025
Merged

Allow for generic MethodTableView #494

merged 5 commits into from
Jun 24, 2025

Conversation

vchuravy
Copy link
Member

@vchuravy vchuravy commented Aug 16, 2023

Currently, SPIRVIntrinsics, OpenCL, oneAPI, and POCL all share one method-table. We would like to disambiguate package specific overlays from the actual common intrinsic definitions.

@codecov
Copy link

codecov bot commented Aug 16, 2023

Codecov Report

Attention: Patch coverage is 90.69767% with 4 lines in your changes missing coverage. Please review.

Project coverage is 73.71%. Comparing base (636d916) to head (bf8e2b8).
Report is 1 commits behind head on master.

Files with missing lines Patch % Lines
src/jlgen.jl 90.47% 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #494      +/-   ##
==========================================
+ Coverage   73.54%   73.71%   +0.17%     
==========================================
  Files          24       24              
  Lines        3485     3519      +34     
==========================================
+ Hits         2563     2594      +31     
- Misses        922      925       +3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Base automatically changed from vc/mt_cleanup to master August 17, 2023 18:10
@maleadt maleadt force-pushed the master branch 5 times, most recently from 1d233d7 to e18b7c2 Compare January 20, 2025 10:33
@vchuravy vchuravy marked this pull request as ready for review June 11, 2025 08:42
@vchuravy
Copy link
Member Author

For now just rebased and included the code from Shenanigans.jl

Copy link
Contributor

github-actions bot commented Jun 11, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/interface.jl b/src/interface.jl
index 7553dbb..cb0c05f 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -231,12 +231,14 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false
 # provide a specific interpreter to use.
 if VERSION >= v"1.11.0-DEV.1552"
 get_interpreter(@nospecialize(job::CompilerJob)) =
-    GPUInterpreter(job.world; method_table_view=maybe_cached(method_table_view(job)),
+        GPUInterpreter(
+        job.world; method_table_view = maybe_cached(method_table_view(job)),
                    token=ci_cache_token(job), inf_params=inference_params(job),
                    opt_params=optimization_params(job))
 else
 get_interpreter(@nospecialize(job::CompilerJob)) =
-    GPUInterpreter(job.world; method_table_view=maybe_cached(method_table_view(job)),
+        GPUInterpreter(
+        job.world; method_table_view = maybe_cached(method_table_view(job)),
                    code_cache=ci_cache(job), inf_params=inference_params(job),
                    opt_params=optimization_params(job))
 end
diff --git a/src/jlgen.jl b/src/jlgen.jl
index 1bff8a0..4982c64 100644
--- a/src/jlgen.jl
+++ b/src/jlgen.jl
@@ -300,7 +300,7 @@ Base.Experimental.@MethodTable(GLOBAL_METHOD_TABLE)
 # Implements a priority lookup for method tables, where the first match in the stack get's returned.
 # An alternative to this would be to use a "Union" where we would query the parent method table and
 # do a most-specific match.
-struct StackedMethodTable{MTV<:CC.MethodTableView} <: CC.MethodTableView
+struct StackedMethodTable{MTV <: CC.MethodTableView} <: CC.MethodTableView
     world::UInt
     mt::Core.MethodTable
     parent::MTV
@@ -313,7 +313,7 @@ CC.isoverlayed(::StackedMethodTable) = true
 @static if VERSION >= v"1.11.0-DEV.363"
     # https://github.com/JuliaLang/julia/pull/51078
     # same API as before but without returning isoverlayed flag
-    function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1)
+    function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int = -1)
         result = CC._findall(sig, table.mt, table.world, limit)
         result === nothing && return nothing # to many matches
         nr = CC.length(result)
@@ -321,17 +321,19 @@ CC.isoverlayed(::StackedMethodTable) = true
             # no need to fall back to the parent method view
             return result
         end
-    
+
         parent_result = CC.findall(sig, table.parent; limit)::Union{Nothing, CC.MethodLookupResult}
         parent_result === nothing && return nothing #too many matches
-    
+
         # merge the parent match results with the internal method table
         return CC.MethodLookupResult(
             CC.vcat(result.matches, parent_result.matches),
             CC.WorldRange(
                 CC.max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world),
-                CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)),
-            result.ambig | parent_result.ambig)
+                CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)
+            ),
+            result.ambig | parent_result.ambig
+        )
     end
 
     function CC.findsup(@nospecialize(sig::Type), table::StackedMethodTable)
@@ -342,11 +344,12 @@ CC.isoverlayed(::StackedMethodTable) = true
             parent_match,
             CC.WorldRange(
                 max(valid_worlds.min_world, parent_valid_worlds.min_world),
-                min(valid_worlds.max_world, parent_valid_worlds.max_world))
-            )
+                min(valid_worlds.max_world, parent_valid_worlds.max_world)
+            ),
+        )
     end
 else
-    function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1)
+    function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int = -1)
         result = CC._findall(sig, table.mt, table.world, limit)
         result === nothing && return nothing # to many matches
         nr = CC.length(result)
@@ -354,22 +357,25 @@ else
             # no need to fall back to the parent method view
             return CC.MethodMatchResult(result, true)
         end
-    
+
         parent_result = CC.findall(sig, table.parent; limit)::Union{Nothing, CC.MethodMatchResult}
         parent_result === nothing && return nothing #too many matches
-    
+
         overlayed = parent_result.overlayed | !CC.isempty(result)
         parent_result = parent_result.matches::CC.MethodLookupResult
-        
+
         # merge the parent match results with the internal method table
         return CC.MethodMatchResult(
-        CC.MethodLookupResult(
-            CC.vcat(result.matches, parent_result.matches),
-            CC.WorldRange(
-                CC.max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world),
-                CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)),
-            result.ambig | parent_result.ambig),
-        overlayed)
+            CC.MethodLookupResult(
+                CC.vcat(result.matches, parent_result.matches),
+                CC.WorldRange(
+                    CC.max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world),
+                    CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)
+                ),
+                result.ambig | parent_result.ambig
+            ),
+            overlayed
+        )
     end
 
     function CC.findsup(@nospecialize(sig::Type), table::StackedMethodTable)
@@ -380,8 +386,10 @@ else
             parent_match,
             CC.WorldRange(
                 max(valid_worlds.min_world, parent_valid_worlds.min_world),
-                min(valid_worlds.max_world, parent_valid_worlds.max_world)),
-            overlayed)
+                min(valid_worlds.max_world, parent_valid_worlds.max_world)
+            ),
+            overlayed,
+        )
     end
 end
 
@@ -404,7 +412,7 @@ end
 
 get_method_table_view(world::UInt, mt::CC.MethodTable) = CC.OverlayMethodTable(world, mt)
 
-struct GPUInterpreter{MTV<:CC.MethodTableView} <: CC.AbstractInterpreter
+struct GPUInterpreter{MTV <: CC.MethodTableView} <: CC.AbstractInterpreter
     world::UInt
     method_table_view::MTV
 
@@ -421,7 +429,7 @@ end
 
 @static if HAS_INTEGRATED_CACHE
 function GPUInterpreter(world::UInt=Base.get_world_counter();
-                        method_table_view::CC.MethodTableView,
+            method_table_view::CC.MethodTableView,
                         token::Any,
                         inf_params::CC.InferenceParams,
                         opt_params::CC.OptimizationParams)
@@ -429,19 +437,21 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
 
     inf_cache = Vector{CC.InferenceResult}()
 
-    return GPUInterpreter(world, method_table_view,
+        return GPUInterpreter(
+            world, method_table_view,
                           token, inf_cache,
                           inf_params, opt_params)
 end
 
 function GPUInterpreter(interp::GPUInterpreter;
                         world::UInt=interp.world,
-                        method_table_view::Core.MethodTable=interp.method_table_view,
+            method_table_view::Core.MethodTable = interp.method_table_view,
                         token::Any=interp.token,
                         inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
                         inf_params::CC.InferenceParams=interp.inf_params,
                         opt_params::CC.OptimizationParams=interp.opt_params)
-    return GPUInterpreter(world, method_table_view,
+        return GPUInterpreter(
+            world, method_table_view,
                           token, inf_cache,
                           inf_params, opt_params)
 end
@@ -449,7 +459,7 @@ end
 else
 
 function GPUInterpreter(world::UInt=Base.get_world_counter();
-                        method_table_view::CC.MethodTableView,
+            method_table_view::CC.MethodTableView,
                         code_cache::CodeCache,
                         inf_params::CC.InferenceParams,
                         opt_params::CC.OptimizationParams)
@@ -457,19 +467,21 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
 
     inf_cache = Vector{CC.InferenceResult}()
 
-    return GPUInterpreter(world, method_table_view,
+        return GPUInterpreter(
+            world, method_table_view,
                           code_cache, inf_cache,
                           inf_params, opt_params)
 end
 
 function GPUInterpreter(interp::GPUInterpreter;
                         world::UInt=interp.world,
-                        method_table_view::CC.MethodTableView=interp.method_table_view,
+            method_table_view::CC.MethodTableView = interp.method_table_view,
                         code_cache::CodeCache=interp.code_cache,
                         inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
                         inf_params::CC.InferenceParams=interp.inf_params,
                         opt_params::CC.OptimizationParams=interp.opt_params)
-    return GPUInterpreter(world, method_table_view,
+        return GPUInterpreter(
+            world, method_table_view,
                           code_cache, inf_cache,
                           inf_params, opt_params)
 end
diff --git a/test/utils.jl b/test/utils.jl
index f0de138..f70cd94 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -111,14 +111,14 @@ DoubleStackedMT() = StackedMethodTable(Base.get_world_counter(), OtherMT, LayerM
         @test isoverlayed(DoubleStackedMT()) == true
     end
 
-    o_sin  = findsup(Tuple{typeof(sin), Float64}, OverlayMT())
-    s_sin  = findsup(Tuple{typeof(sin), Float64}, StackedMT())
+    o_sin = findsup(Tuple{typeof(sin), Float64}, OverlayMT())
+    s_sin = findsup(Tuple{typeof(sin), Float64}, StackedMT())
     ss_sin = findsup(Tuple{typeof(sin), Float64}, DoubleStackedMT())
     @test s_sin == o_sin
     @test ss_sin == o_sin
 
-    o_sin  = findall(Tuple{typeof(sin), Float64}, OverlayMT())
-    s_sin  = findall(Tuple{typeof(sin), Float64}, StackedMT())
+    o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT())
+    s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT())
     ss_sin = findall(Tuple{typeof(sin), Float64}, DoubleStackedMT())
     if VERSION >= v"1.11.0-DEV.363"
         @test o_sin.matches == s_sin.matches
@@ -150,8 +150,8 @@ next_world = Base.get_world_counter()
     @test worlds.min_world > prev_world
     @test worlds.max_world == typemax(typeof(next_world))
 
-    o_sin  = findall(Tuple{typeof(sin), Float64}, OverlayMT())
-    s_sin  = findall(Tuple{typeof(sin), Float64}, StackedMT())
+    o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT())
+    s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT())
     ss_sin = findall(Tuple{typeof(sin), Float64}, DoubleStackedMT())
     if VERSION >= v"1.11.0-DEV.363"
         @test o_sin.matches == s_sin.matches

@vchuravy
Copy link
Member Author

@aviatesk can you take a look if this makes sense? The goal here is to have a stack of three method tables, the first layer superseding the second, the second superseding the third. The third is Julia's internal method table.

Copy link
Contributor

@aviatesk aviatesk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From an implementation perspective, this seems fine.
One point to note is that in the current overlay method table mehcanism, simply looking at .fully_covers cannot reproduce the specialization behavior that occurs in normal dispatch.
Consider the following example:

Base.Experimental.@MethodTable mt

func(a::Number) = :number

func(a::Integer) = :integer

Base.Experimental.@overlay mt func(a::Number) = :number_overlay

In this case, the correct specialization for Tuple{typeof(func), Int} should naturally be dispatch to func(::Integer):

julia> Base.Compiler.findall(Tuple{typeof(func),Int}, Base.Compiler.InternalMethodTable(Base.get_world_counter()))
Compiler.MethodLookupResult(Any[Core.MethodMatch(Tuple{typeof(func), Int64}, svec(), func(a::Integer) @ Main REPL[14]:1, true)], Compiler.WorldRange(0x00000000000097cd, 0xffffffffffffffff), false)

However, when using overlay method tables, this rule is not preserved, and something like "local specialization within the overlay method table" takes in place:

julia> result = Base.Compiler.findall(Tuple{typeof(func),Int}, Base.Compiler.OverlayMethodTable(Base.get_world_counter(), mt))
Compiler.MethodLookupResult(Any[Core.MethodMatch(Tuple{typeof(func), Int64}, svec(), func(a::Number) @ Main REPL[15]:1, true)], Compiler.WorldRange(0x00000000000097ce, 0xffffffffffffffff), false)

julia> result.matches[1].fully_covers
true

This issue likely applies to this PR as well. It might be worth being careful about whether there are any problems regarding this point. This is more of an implementation issue with Base.Compiler.findall(type, ::OverlayMethodTable) rather than a problem on the side that uses OverlayMethodTable though.

@vchuravy
Copy link
Member Author

This issue likely applies to this PR as well. It might be worth being careful about whether there are any problems regarding this point.

For our uses that is fine, but yes for some of the potential use-cases that might be an issue.

@vchuravy vchuravy merged commit 79e0f56 into master Jun 24, 2025
19 of 22 checks passed
@vchuravy vchuravy deleted the vc/mtv branch June 24, 2025 13:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants