Skip to content

Commit 2a0eb70

Browse files
authored
Allow Module as type parameters (#47749)
The intended use case for this is generated functions that want to generate some reference to a module-specific generic function. The current solution is to duplicate the generated function into every module (probably using a package-provided macro) or to have some sort of registry system in the package providing the generated function. Both of these seem a bit ugly and I don't think there's any particularly good reason not to allow Modules to be type parameters. Admittedly, modules are not part of the scope contemplated by #33387 as they are mutable, but I think the mental model of modules is that they're immutable references to a namespace and what's actually mutable is the namespace itself (i.e. people wouldn't expect two modules that happen to have the same content be `==`). This makes me think it still fits the mental model.
1 parent 89671ae commit 2a0eb70

File tree

8 files changed

+57
-19
lines changed

8 files changed

+57
-19
lines changed

base/compiler/typeutils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,12 @@ function valid_as_lattice(@nospecialize(x))
132132
end
133133

134134
function valid_typeof_tparam(@nospecialize(t))
135-
if t === Symbol || isbitstype(t)
135+
if t === Symbol || t === Module || isbitstype(t)
136136
return true
137137
end
138138
isconcretetype(t) || return false
139139
if t <: NamedTuple
140-
t = t.parameters[2]
140+
t = t.parameters[2]::DataType
141141
end
142142
if t <: Tuple
143143
for p in t.parameters

src/builtins.c

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ static int egal_types(const jl_value_t *a, const jl_value_t *b, jl_typeenv_t *en
198198
return egal_types(vma->N, vmb->N, env, tvar_names);
199199
return !vma->N && !vmb->N;
200200
}
201-
if (dt == jl_symbol_type)
201+
if (dt == jl_symbol_type || dt == jl_module_type)
202202
return 0;
203203
assert(!dt->name->mutabl);
204204
return jl_egal__bits(a, b, dt);
@@ -414,6 +414,10 @@ static uintptr_t NOINLINE jl_object_id__cold(jl_datatype_t *dt, jl_value_t *v) J
414414
return memhash32_seed(jl_string_data(v), jl_string_len(v), 0xedc3b677);
415415
#endif
416416
}
417+
if (dt == jl_module_type) {
418+
jl_module_t *m = (jl_module_t*)v;
419+
return m->hash;
420+
}
417421
if (dt->name->mutabl)
418422
return inthash((uintptr_t)v);
419423
return immut_id_(dt, v, dt->hash);
@@ -1269,7 +1273,8 @@ static int is_nestable_type_param(jl_value_t *t)
12691273
size_t i, l = jl_nparams(t);
12701274
for (i = 0; i < l; i++) {
12711275
jl_value_t *pi = jl_tparam(t, i);
1272-
if (!(pi == (jl_value_t*)jl_symbol_type || jl_isbits(pi) || is_nestable_type_param(pi)))
1276+
if (!(pi == (jl_value_t*)jl_symbol_type || jl_isbits(pi) || is_nestable_type_param(pi) ||
1277+
jl_is_module(pi)))
12731278
return 0;
12741279
}
12751280
return 1;
@@ -1284,7 +1289,8 @@ int jl_valid_type_param(jl_value_t *v)
12841289
if (jl_is_vararg(v))
12851290
return 0;
12861291
// TODO: maybe more things
1287-
return jl_is_type(v) || jl_is_typevar(v) || jl_is_symbol(v) || jl_isbits(jl_typeof(v));
1292+
return jl_is_type(v) || jl_is_typevar(v) || jl_is_symbol(v) || jl_isbits(jl_typeof(v)) ||
1293+
jl_is_module(v);
12881294
}
12891295

12901296
JL_CALLABLE(jl_f_apply_type)
@@ -1896,7 +1902,7 @@ void jl_init_intrinsic_properties(void) JL_GC_DISABLED
18961902

18971903
void jl_init_intrinsic_functions(void) JL_GC_DISABLED
18981904
{
1899-
jl_module_t *inm = jl_new_module(jl_symbol("Intrinsics"));
1905+
jl_module_t *inm = jl_new_module(jl_symbol("Intrinsics"), NULL);
19001906
inm->parent = jl_core_module;
19011907
jl_set_const(jl_core_module, jl_symbol("Intrinsics"), (jl_value_t*)inm);
19021908
jl_mk_builtin_func(jl_intrinsic_type, "IntrinsicFunction", jl_f_intrinsic_call);

src/init.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -821,7 +821,7 @@ static NOINLINE void _finish_julia_init(JL_IMAGE_SEARCH rel, jl_ptls_t ptls, jl_
821821
jl_init_serializer();
822822

823823
if (!jl_options.image_file) {
824-
jl_core_module = jl_new_module(jl_symbol("Core"));
824+
jl_core_module = jl_new_module(jl_symbol("Core"), NULL);
825825
jl_core_module->parent = jl_core_module;
826826
jl_type_typename->mt->module = jl_core_module;
827827
jl_top_module = jl_core_module;

src/julia.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,7 @@ typedef struct _jl_module_t {
603603
uint8_t istopmod;
604604
int8_t max_methods;
605605
jl_mutex_t lock;
606+
intptr_t hash;
606607
} jl_module_t;
607608

608609
typedef struct {
@@ -1615,7 +1616,7 @@ extern JL_DLLEXPORT jl_module_t *jl_main_module JL_GLOBALLY_ROOTED;
16151616
extern JL_DLLEXPORT jl_module_t *jl_core_module JL_GLOBALLY_ROOTED;
16161617
extern JL_DLLEXPORT jl_module_t *jl_base_module JL_GLOBALLY_ROOTED;
16171618
extern JL_DLLEXPORT jl_module_t *jl_top_module JL_GLOBALLY_ROOTED;
1618-
JL_DLLEXPORT jl_module_t *jl_new_module(jl_sym_t *name);
1619+
JL_DLLEXPORT jl_module_t *jl_new_module(jl_sym_t *name, jl_module_t *parent);
16191620
JL_DLLEXPORT void jl_set_module_nospecialize(jl_module_t *self, int on);
16201621
JL_DLLEXPORT void jl_set_module_optlevel(jl_module_t *self, int lvl);
16211622
JL_DLLEXPORT int jl_get_module_optlevel(jl_module_t *m);

src/module.c

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@
1111
extern "C" {
1212
#endif
1313

14-
JL_DLLEXPORT jl_module_t *jl_new_module_(jl_sym_t *name, uint8_t default_names)
14+
JL_DLLEXPORT jl_module_t *jl_new_module_(jl_sym_t *name, jl_module_t *parent, uint8_t default_names)
1515
{
1616
jl_task_t *ct = jl_current_task;
1717
const jl_uuid_t uuid_zero = {0, 0};
1818
jl_module_t *m = (jl_module_t*)jl_gc_alloc(ct->ptls, sizeof(jl_module_t),
1919
jl_module_type);
2020
assert(jl_is_symbol(name));
2121
m->name = name;
22-
m->parent = NULL;
22+
m->parent = parent;
2323
m->istopmod = 0;
2424
m->uuid = uuid_zero;
2525
static unsigned int mcounter; // simple counter backup, in case hrtime is not incrementing
@@ -34,6 +34,8 @@ JL_DLLEXPORT jl_module_t *jl_new_module_(jl_sym_t *name, uint8_t default_names)
3434
m->compile = -1;
3535
m->infer = -1;
3636
m->max_methods = -1;
37+
m->hash = parent == NULL ? bitmix(name->hash, jl_module_type->hash) :
38+
bitmix(name->hash, parent->hash);
3739
JL_MUTEX_INIT(&m->lock);
3840
htable_new(&m->bindings, 0);
3941
arraylist_new(&m->usings, 0);
@@ -50,9 +52,9 @@ JL_DLLEXPORT jl_module_t *jl_new_module_(jl_sym_t *name, uint8_t default_names)
5052
return m;
5153
}
5254

53-
JL_DLLEXPORT jl_module_t *jl_new_module(jl_sym_t *name)
55+
JL_DLLEXPORT jl_module_t *jl_new_module(jl_sym_t *name, jl_module_t *parent)
5456
{
55-
return jl_new_module_(name, 1);
57+
return jl_new_module_(name, parent, 1);
5658
}
5759

5860
uint32_t jl_module_next_counter(jl_module_t *m)
@@ -63,10 +65,9 @@ uint32_t jl_module_next_counter(jl_module_t *m)
6365
JL_DLLEXPORT jl_value_t *jl_f_new_module(jl_sym_t *name, uint8_t std_imports, uint8_t default_names)
6466
{
6567
// TODO: should we prohibit this during incremental compilation?
66-
jl_module_t *m = jl_new_module_(name, default_names);
68+
// TODO: the parent module is a lie
69+
jl_module_t *m = jl_new_module_(name, jl_main_module, default_names);
6770
JL_GC_PUSH1(&m);
68-
m->parent = jl_main_module; // TODO: this is a lie
69-
jl_gc_wb(m, m->parent);
7071
if (std_imports)
7172
jl_add_standard_imports(m);
7273
JL_GC_POP();

src/toplevel.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ JL_DLLEXPORT void jl_add_standard_imports(jl_module_t *m)
4848
void jl_init_main_module(void)
4949
{
5050
assert(jl_main_module == NULL);
51-
jl_main_module = jl_new_module(jl_symbol("Main"));
51+
jl_main_module = jl_new_module(jl_symbol("Main"), NULL);
5252
jl_main_module->parent = jl_main_module;
5353
jl_set_const(jl_main_module, jl_symbol("Core"),
5454
(jl_value_t*)jl_core_module);
@@ -134,7 +134,8 @@ static jl_value_t *jl_eval_module_expr(jl_module_t *parent_module, jl_expr_t *ex
134134
jl_type_error("module", (jl_value_t*)jl_symbol_type, (jl_value_t*)name);
135135
}
136136

137-
jl_module_t *newm = jl_new_module(name);
137+
int is_parent__toplevel__ = jl_is__toplevel__mod(parent_module);
138+
jl_module_t *newm = jl_new_module(name, is_parent__toplevel__ ? NULL : parent_module);
138139
jl_value_t *form = (jl_value_t*)newm;
139140
JL_GC_PUSH1(&form);
140141
JL_LOCK(&jl_modules_mutex);
@@ -145,15 +146,14 @@ static jl_value_t *jl_eval_module_expr(jl_module_t *parent_module, jl_expr_t *ex
145146

146147
// copy parent environment info into submodule
147148
newm->uuid = parent_module->uuid;
148-
if (jl_is__toplevel__mod(parent_module)) {
149+
if (is_parent__toplevel__) {
149150
newm->parent = newm;
150151
jl_register_root_module(newm);
151152
if (jl_options.incremental) {
152153
jl_precompile_toplevel_module = newm;
153154
}
154155
}
155156
else {
156-
newm->parent = parent_module;
157157
jl_binding_t *b = jl_get_binding_wr(parent_module, name, 1);
158158
jl_declare_constant(b);
159159
jl_value_t *old = NULL;

test/core.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7873,3 +7873,16 @@ let # https://github.com/JuliaLang/julia/issues/46918
78737873
@test isempty(String(take!(stderr))) # make sure no error has happened
78747874
@test String(take!(stdout)) == "nothing IO IO"
78757875
end
7876+
7877+
# Modules allowed as type parameters and usable in generated functions
7878+
module ModTparamTest
7879+
foo_test_mod_tparam() = 1
7880+
end
7881+
foo_test_mod_tparam() = 2
7882+
7883+
struct ModTparamTestStruct{M}; end
7884+
@generated function ModTparamTestStruct{M}() where {M}
7885+
return :($(GlobalRef(M, :foo_test_mod_tparam))())
7886+
end
7887+
@test ModTparamTestStruct{@__MODULE__}() == 2
7888+
@test ModTparamTestStruct{ModTparamTest}() == 1

test/precompile.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,6 +1580,23 @@ end
15801580
@test which(f46778, Tuple{Any,DataType}).specializations[1].cache.invoke != C_NULL
15811581
end
15821582

1583+
1584+
precompile_test_harness("Module tparams") do load_path
1585+
write(joinpath(load_path, "ModuleTparams.jl"),
1586+
"""
1587+
module ModuleTparams
1588+
module TheTParam
1589+
end
1590+
1591+
struct ParamStruct{T}; end
1592+
const the_struct = ParamStruct{TheTParam}()
1593+
end
1594+
""")
1595+
Base.compilecache(Base.PkgId("ModuleTparams"))
1596+
(@eval (using ModuleTparams))
1597+
@test ModuleTparams.the_struct === Base.invokelatest(ModuleTparams.ParamStruct{ModuleTparams.TheTParam})
1598+
end
1599+
15831600
empty!(Base.DEPOT_PATH)
15841601
append!(Base.DEPOT_PATH, original_depot_path)
15851602
empty!(Base.LOAD_PATH)

0 commit comments

Comments
 (0)