Skip to content

External Method Tables #39697

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ struct InternalMethodTable <: MethodTableView
world::UInt
end

"""
struct OverlayMethodTable <: MethodTableView

Overlays the internal method table such that specific queries can be redirected to an
external table, e.g., to override existing method.
"""
struct OverlayMethodTable <: MethodTableView
world::UInt
mt::Core.MethodTable
end

"""
struct CachedMethodTable <: MethodTableView

Expand All @@ -54,7 +65,26 @@ function findall(@nospecialize(sig::Type{<:Tuple}), table::InternalMethodTable;
_min_val = RefValue{UInt}(typemin(UInt))
_max_val = RefValue{UInt}(typemax(UInt))
_ambig = RefValue{Int32}(0)
ms = _methods_by_ftype(sig, limit, table.world, false, _min_val, _max_val, _ambig)
ms = _methods_by_ftype(sig, nothing, limit, table.world, false, _min_val, _max_val, _ambig)
if ms === false
return missing
end
return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0)
end

function findall(@nospecialize(sig::Type{<:Tuple}), table::OverlayMethodTable; limit::Int=typemax(Int))
_min_val = RefValue{UInt}(typemin(UInt))
_max_val = RefValue{UInt}(typemax(UInt))
_ambig = RefValue{Int32}(0)
ms = _methods_by_ftype(sig, table.mt, limit, table.world, false, _min_val, _max_val, _ambig)
if ms === false
return missing
elseif isempty(ms)
# fall back to the internal method table
_min_val[] = typemin(UInt)
_max_val[] = typemax(UInt)
ms = _methods_by_ftype(sig, nothing, limit, table.world, false, _min_val, _max_val, _ambig)
end
if ms === false
return missing
end
Expand Down
33 changes: 33 additions & 0 deletions base/experimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
module Experimental

using Base: Threads, sync_varname
using Base.Meta

"""
Const(A::Array)
Expand Down Expand Up @@ -255,4 +256,36 @@ end
# OpaqueClosure
include("opaque_closure.jl")

"""
Experimental.@overlay mt [function def]

Define a method and add it to the method table `mt` instead of to the global method table.
This can be used to implement a method override mechanism. Regular compilation will not
consider these methods, and you should customize the compilation flow to look in these
method tables (e.g., using [`Core.Compiler.OverlayMethodTable`](@ref)).

"""
macro overlay(mt, def)
def = macroexpand(__module__, def) # to expand @inline, @generated, etc
if !isexpr(def, [:function, :(=)]) || !isexpr(def.args[1], :call)
error("@overlay requires a function Expr")
end
def.args[1].args[1] = Expr(:overlay, mt, def.args[1].args[1])
esc(def)
end

"""
Experimental.@MethodTable(name)

Create a new MethodTable in the current module, bound to `name`. This method table can be
used with the [`Experimental.@overlay`](@ref) macro to define methods for a function without
adding them to the global method table.
"""
macro MethodTable(name)
isa(name, Symbol) || error("name must be a symbol")
esc(quote
const $name = ccall(:jl_new_method_table, Any, (Any, Any), $(quot(name)), $(__module__))
end)
end

end
17 changes: 10 additions & 7 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -878,13 +878,16 @@ function _methods(@nospecialize(f), @nospecialize(t), lim::Int, world::UInt)
end

function _methods_by_ftype(@nospecialize(t), lim::Int, world::UInt)
return _methods_by_ftype(t, lim, world, false, RefValue{UInt}(typemin(UInt)), RefValue{UInt}(typemax(UInt)), Ptr{Int32}(C_NULL))
return _methods_by_ftype(t, nothing, lim, world)
end
function _methods_by_ftype(@nospecialize(t), lim::Int, world::UInt, ambig::Bool, min::Array{UInt,1}, max::Array{UInt,1}, has_ambig::Array{Int32,1})
return ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool}
function _methods_by_ftype(@nospecialize(t), mt::Union{Core.MethodTable, Nothing}, lim::Int, world::UInt)
return _methods_by_ftype(t, mt, lim, world, false, RefValue{UInt}(typemin(UInt)), RefValue{UInt}(typemax(UInt)), Ptr{Int32}(C_NULL))
end
function _methods_by_ftype(@nospecialize(t), lim::Int, world::UInt, ambig::Bool, min::Ref{UInt}, max::Ref{UInt}, has_ambig::Ref{Int32})
return ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool}
function _methods_by_ftype(@nospecialize(t), mt::Union{Core.MethodTable, Nothing}, lim::Int, world::UInt, ambig::Bool, min::Array{UInt,1}, max::Array{UInt,1}, has_ambig::Array{Int32,1})
return ccall(:jl_matching_methods, Any, (Any, Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, mt, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool}
end
function _methods_by_ftype(@nospecialize(t), mt::Union{Core.MethodTable, Nothing}, lim::Int, world::UInt, ambig::Bool, min::Ref{UInt}, max::Ref{UInt}, has_ambig::Ref{Int32})
return ccall(:jl_matching_methods, Any, (Any, Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, mt, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool}
end

function _method_by_ftype(args...)
Expand Down Expand Up @@ -952,7 +955,7 @@ function methods_including_ambiguous(@nospecialize(f), @nospecialize(t))
world = typemax(UInt)
min = RefValue{UInt}(typemin(UInt))
max = RefValue{UInt}(typemax(UInt))
ms = _methods_by_ftype(tt, -1, world, true, min, max, Ptr{Int32}(C_NULL))
ms = _methods_by_ftype(tt, nothing, -1, world, true, min, max, Ptr{Int32}(C_NULL))
isa(ms, Bool) && return ms
return MethodList(Method[(m::Core.MethodMatch).method for m in ms], typeof(f).name.mt)
end
Expand Down Expand Up @@ -1508,7 +1511,7 @@ function isambiguous(m1::Method, m2::Method; ambiguous_bottom::Bool=false)
min = UInt[typemin(UInt)]
max = UInt[typemax(UInt)]
has_ambig = Int32[0]
ms = _methods_by_ftype(ti, -1, typemax(UInt), true, min, max, has_ambig)::Vector
ms = _methods_by_ftype(ti, nothing, -1, typemax(UInt), true, min, max, has_ambig)::Vector
has_ambig[] == 0 && return false
if !ambiguous_bottom
filter!(ms) do m::Core.MethodMatch
Expand Down
98 changes: 51 additions & 47 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ static const auto jlinvoke_func = new JuliaFunction{
static const auto jlmethod_func = new JuliaFunction{
"jl_method_def",
[](LLVMContext &C) { return FunctionType::get(T_prjlvalue,
{T_prjlvalue, T_prjlvalue, T_pjlvalue}, false); },
{T_prjlvalue, T_prjlvalue, T_prjlvalue, T_pjlvalue}, false); },
nullptr,
};
static const auto jlgenericfunction_func = new JuliaFunction{
Expand Down Expand Up @@ -4415,58 +4415,62 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaval)
return emit_sparam(ctx, jl_unbox_long(args[0]) - 1);
}
else if (head == method_sym) {
jl_value_t *mn = args[0];
assert(jl_expr_nargs(ex) != 1 || jl_is_symbol(mn) || jl_is_slot(mn));
if (jl_expr_nargs(ex) == 1) {
jl_value_t *mn = args[0];
assert(jl_expr_nargs(ex) != 1 || jl_is_symbol(mn) || jl_is_slot(mn));

Value *bp = NULL, *name, *bp_owner = V_null;
jl_binding_t *bnd = NULL;
bool issym = jl_is_symbol(mn);
bool isglobalref = !issym && jl_is_globalref(mn);
jl_module_t *mod = ctx.module;
if (issym || isglobalref) {
if (isglobalref) {
mod = jl_globalref_mod(mn);
mn = (jl_value_t*)jl_globalref_name(mn);
}
JL_TRY {
if (jl_symbol_name((jl_sym_t*)mn)[0] == '@')
jl_errorf("macro definition not allowed inside a local scope");
name = literal_pointer_val(ctx, mn);
bnd = jl_get_binding_for_method_def(mod, (jl_sym_t*)mn);
}
JL_CATCH {
jl_value_t *e = jl_current_exception();
// errors. boo. root it somehow :(
bnd = jl_get_binding_wr(ctx.module, (jl_sym_t*)jl_gensym(), 1);
bnd->value = e;
bnd->constp = 1;
raise_exception(ctx, literal_pointer_val(ctx, e));
return ghostValue(jl_nothing_type);
}
bp = julia_binding_gv(ctx, bnd);
bp_owner = literal_pointer_val(ctx, (jl_value_t*)mod);
}
else if (jl_is_slot(mn) || jl_is_argument(mn)) {
int sl = jl_slot_number(mn)-1;
jl_varinfo_t &vi = ctx.slots[sl];
bp = vi.boxroot;
name = literal_pointer_val(ctx, (jl_value_t*)slot_symbol(ctx, sl));
}
if (bp) {
Value *mdargs[5] = { name, literal_pointer_val(ctx, (jl_value_t*)mod), bp,
bp_owner, literal_pointer_val(ctx, bnd) };
jl_cgval_t gf = mark_julia_type(
ctx,
ctx.builder.CreateCall(prepare_call(jlgenericfunction_func), makeArrayRef(mdargs)),
true,
jl_function_type);
if (jl_expr_nargs(ex) == 1)
Value *bp = NULL, *name, *bp_owner = V_null;
jl_binding_t *bnd = NULL;
bool issym = jl_is_symbol(mn);
bool isglobalref = !issym && jl_is_globalref(mn);
jl_module_t *mod = ctx.module;
if (issym || isglobalref) {
if (isglobalref) {
mod = jl_globalref_mod(mn);
mn = (jl_value_t*)jl_globalref_name(mn);
}
JL_TRY {
if (jl_symbol_name((jl_sym_t*)mn)[0] == '@')
jl_errorf("macro definition not allowed inside a local scope");
name = literal_pointer_val(ctx, mn);
bnd = jl_get_binding_for_method_def(mod, (jl_sym_t*)mn);
}
JL_CATCH {
jl_value_t *e = jl_current_exception();
// errors. boo. root it somehow :(
bnd = jl_get_binding_wr(ctx.module, (jl_sym_t*)jl_gensym(), 1);
bnd->value = e;
bnd->constp = 1;
raise_exception(ctx, literal_pointer_val(ctx, e));
return ghostValue(jl_nothing_type);
}
bp = julia_binding_gv(ctx, bnd);
bp_owner = literal_pointer_val(ctx, (jl_value_t*)mod);
}
else if (jl_is_slot(mn) || jl_is_argument(mn)) {
int sl = jl_slot_number(mn)-1;
jl_varinfo_t &vi = ctx.slots[sl];
bp = vi.boxroot;
name = literal_pointer_val(ctx, (jl_value_t*)slot_symbol(ctx, sl));
}
if (bp) {
Value *mdargs[5] = { name, literal_pointer_val(ctx, (jl_value_t*)mod), bp,
bp_owner, literal_pointer_val(ctx, bnd) };
jl_cgval_t gf = mark_julia_type(
ctx,
ctx.builder.CreateCall(prepare_call(jlgenericfunction_func), makeArrayRef(mdargs)),
true,
jl_function_type);
return gf;
}
emit_error(ctx, "method: invalid declaration");
return jl_cgval_t();
}
Value *a1 = boxed(ctx, emit_expr(ctx, args[1]));
Value *a2 = boxed(ctx, emit_expr(ctx, args[2]));
Value *mdargs[3] = {
Value *mdargs[4] = {
/*argdata*/a1,
ConstantPointerNull::get(cast<PointerType>(T_prjlvalue)),
/*code*/a2,
/*module*/literal_pointer_val(ctx, (jl_value_t*)ctx.module)
};
Expand Down
57 changes: 46 additions & 11 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,11 @@ static void jl_serialize_code_instance(jl_serializer_state *s, jl_code_instance_
jl_serialize_code_instance(s, codeinst->next, skip_partial_opaque);
}

enum METHOD_SERIALIZATION_MODE {
METHOD_INTERNAL = 1,
METHOD_EXTERNAL_MT = 2,
};

static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_literal) JL_GC_DISABLED
{
if (jl_serialize_generic(s, v)) {
Expand Down Expand Up @@ -627,18 +632,34 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
else if (jl_is_method(v)) {
write_uint8(s->s, TAG_METHOD);
jl_method_t *m = (jl_method_t*)v;
int internal = 1;
internal = m->is_for_opaque_closure || module_in_worklist(m->module);
if (!internal) {
int serialization_mode = 0;
if (m->is_for_opaque_closure || module_in_worklist(m->module))
serialization_mode |= METHOD_INTERNAL;
if (!(serialization_mode & METHOD_INTERNAL)) {
// flag this in the backref table as special
uintptr_t *bp = (uintptr_t*)ptrhash_bp(&backref_table, v);
assert(*bp != (uintptr_t)HT_NOTFOUND);
*bp |= 1;
}
jl_serialize_value(s, (jl_value_t*)m->sig);
jl_serialize_value(s, (jl_value_t*)m->module);
write_uint8(s->s, internal);
if (!internal)
if (m->external_mt != NULL) {
assert(jl_typeis(m->external_mt, jl_methtable_type));
jl_methtable_t *mt = (jl_methtable_t*)m->external_mt;
if (!module_in_worklist(mt->module)) {
serialization_mode |= METHOD_EXTERNAL_MT;
}
}
write_uint8(s->s, serialization_mode);
if (serialization_mode & METHOD_EXTERNAL_MT) {
// We reference this method table by module and binding
jl_methtable_t *mt = (jl_methtable_t*)m->external_mt;
jl_serialize_value(s, mt->module);
jl_serialize_value(s, mt->name);
} else {
jl_serialize_value(s, (jl_value_t*)m->external_mt);
}
if (!(serialization_mode & METHOD_INTERNAL))
return;
jl_serialize_value(s, m->specializations);
jl_serialize_value(s, m->speckeyset);
Expand Down Expand Up @@ -951,6 +972,10 @@ static void jl_collect_lambdas_from_mod(jl_array_t *s, jl_module_t *m) JL_GC_DIS
jl_collect_lambdas_from_mod(s, (jl_module_t*)b->value);
}
}
else if (jl_is_mtable(bv)) {
// a module containing an external method table
jl_collect_methtable_from_mod(s, (jl_methtable_t*)bv);
}
}
}
}
Expand Down Expand Up @@ -1014,7 +1039,7 @@ static void jl_collect_backedges(jl_array_t *s, jl_array_t *t)
size_t min_valid = 0;
size_t max_valid = ~(size_t)0;
int ambig = 0;
jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, -1, 0, jl_world_counter, &min_valid, &max_valid, &ambig);
jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_world_counter, &min_valid, &max_valid, &ambig);
if (matches == jl_false) {
valid = 0;
break;
Expand Down Expand Up @@ -1457,8 +1482,18 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_
jl_gc_wb(m, m->sig);
m->module = (jl_module_t*)jl_deserialize_value(s, (jl_value_t**)&m->module);
jl_gc_wb(m, m->module);
int internal = read_uint8(s->s);
if (!internal) {
int serialization_mode = read_uint8(s->s);
if (serialization_mode & METHOD_EXTERNAL_MT) {
jl_module_t *mt_mod = (jl_module_t*)jl_deserialize_value(s, NULL);
jl_sym_t *mt_name = (jl_sym_t*)jl_deserialize_value(s, NULL);
m->external_mt = jl_get_global(mt_mod, mt_name);
jl_gc_wb(m, m->external_mt);
assert(jl_typeis(m->external_mt, jl_methtable_type));
} else {
m->external_mt = jl_deserialize_value(s, &m->external_mt);
jl_gc_wb(m, m->external_mt);
}
if (!(serialization_mode & METHOD_INTERNAL)) {
assert(loc != NULL && loc != HT_NOTFOUND);
arraylist_push(&flagref_list, loc);
arraylist_push(&flagref_list, (void*)pos);
Expand Down Expand Up @@ -1893,7 +1928,7 @@ static void jl_insert_methods(jl_array_t *list)
assert(!meth->is_for_opaque_closure);
jl_tupletype_t *simpletype = (jl_tupletype_t*)jl_array_ptr_ref(list, i + 1);
assert(jl_is_method(meth));
jl_methtable_t *mt = jl_method_table_for((jl_value_t*)meth->sig);
jl_methtable_t *mt = jl_method_get_table(meth);
assert((jl_value_t*)mt != jl_nothing);
jl_method_table_insert(mt, meth, simpletype);
}
Expand Down Expand Up @@ -1923,7 +1958,7 @@ static void jl_verify_edges(jl_array_t *targets, jl_array_t **pvalids)
size_t max_valid = ~(size_t)0;
int ambig = 0;
// TODO: possibly need to included ambiguities too (for the optimizer correctness)?
jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, -1, 0, jl_world_counter, &min_valid, &max_valid, &ambig);
jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_world_counter, &min_valid, &max_valid, &ambig);
if (matches == jl_false || jl_array_len(matches) != jl_array_len(expected)) {
valid = 0;
}
Expand Down Expand Up @@ -2461,7 +2496,7 @@ static jl_method_t *jl_recache_method(jl_method_t *m)
{
assert(!m->is_for_opaque_closure);
jl_datatype_t *sig = (jl_datatype_t*)m->sig;
jl_methtable_t *mt = jl_method_table_for((jl_value_t*)m->sig);
jl_methtable_t *mt = jl_method_get_table(m);
assert((jl_value_t*)mt != jl_nothing);
jl_set_typeof(m, (void*)(intptr_t)0x30); // invalidate the old value to help catch errors
jl_method_t *_new = jl_lookup_method(mt, sig, m->module->primary_world);
Expand Down
Loading