Skip to content

Commit 1fc9d54

Browse files
committed
Widen diagonal var during Type unwrapping in instanceof_tfunc
1 parent 72cd63c commit 1fc9d54

File tree

5 files changed

+241
-4
lines changed

5 files changed

+241
-4
lines changed

base/compiler/tfuncs.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,25 +95,31 @@ add_tfunc(throw, 1, 1, @nospecs((𝕃::AbstractLattice, x)->Bottom), 0)
9595
# if isexact is false, the actual runtime type may (will) be a subtype of t
9696
# if isconcrete is true, the actual runtime type is definitely concrete (unreachable if not valid as a typeof)
9797
# if istype is true, the actual runtime value will definitely be a type (e.g. this is false for Union{Type{Int}, Int})
98-
function instanceof_tfunc(@nospecialize(t), astag::Bool=false)
98+
function instanceof_tfunc(@nospecialize(t), astag::Bool=false, @nospecialize(troot) = t)
9999
if isa(t, Const)
100100
if isa(t.val, Type) && valid_as_lattice(t.val, astag)
101101
return t.val, true, isconcretetype(t.val), true
102102
end
103103
return Bottom, true, false, false # runtime throws on non-Type
104104
end
105105
t = widenconst(t)
106+
troot = widenconst(troot)
106107
if t === Bottom
107108
return Bottom, true, true, false # runtime unreachable
108109
elseif t === typeof(Bottom) || !hasintersect(t, Type)
109110
return Bottom, true, false, false # literal Bottom or non-Type
110111
elseif isType(t)
111112
tp = t.parameters[1]
112113
valid_as_lattice(tp, astag) || return Bottom, true, false, false # runtime unreachable / throws on non-Type
114+
if troot isa UnionAll
115+
# Free `TypeVar`s inside `Type` has violated the "diagonal" rule.
116+
# Widen them before `UnionAll` rewraping to relax concrete constraint.
117+
tp = widen_diagonal(tp, troot)
118+
end
113119
return tp, !has_free_typevars(tp), isconcretetype(tp), true
114120
elseif isa(t, UnionAll)
115121
t′ = unwrap_unionall(t)
116-
t′′, isexact, isconcrete, istype = instanceof_tfunc(t′, astag)
122+
t′′, isexact, isconcrete, istype = instanceof_tfunc(t′, astag, rewrap_unionall(t, troot))
117123
tr = rewrap_unionall(t′′, t)
118124
if t′′ isa DataType && t′′.name !== Tuple.name && !has_free_typevars(tr)
119125
# a real instance must be within the declared bounds of the type,
@@ -128,8 +134,8 @@ function instanceof_tfunc(@nospecialize(t), astag::Bool=false)
128134
end
129135
return tr, isexact, isconcrete, istype
130136
elseif isa(t, Union)
131-
ta, isexact_a, isconcrete_a, istype_a = instanceof_tfunc(t.a, astag)
132-
tb, isexact_b, isconcrete_b, istype_b = instanceof_tfunc(t.b, astag)
137+
ta, isexact_a, isconcrete_a, istype_a = instanceof_tfunc(t.a, astag, troot)
138+
tb, isexact_b, isconcrete_b, istype_b = instanceof_tfunc(t.b, astag, troot)
133139
isconcrete = isconcrete_a && isconcrete_b
134140
istype = istype_a && istype_b
135141
# most users already handle the Union case, so here we assume that

base/essentials.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,11 @@ function rename_unionall(@nospecialize(u))
459459
return UnionAll(nv, body{nv})
460460
end
461461

462+
# remove concrete constraint on diagonal TypeVar if it comes from troot
463+
function widen_diagonal(@nospecialize(t), troot::UnionAll)
464+
body = ccall(:jl_widen_diagonal, Any, (Any, Any), t, troot)
465+
end
466+
462467
function isvarargtype(@nospecialize(t))
463468
return isa(t, Core.TypeofVararg)
464469
end

src/subtype.c

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4304,6 +4304,211 @@ int jl_subtype_matching(jl_value_t *a, jl_value_t *b, jl_svec_t **penv)
43044304
return sub;
43054305
}
43064306

4307+
// type utils
4308+
static void check_diagonal(jl_value_t *t, jl_varbinding_t *troot, int param)
4309+
{
4310+
if (jl_is_uniontype(t)) {
4311+
int i, len = 0;
4312+
jl_varbinding_t *v;
4313+
for (v = troot; v != NULL; v = v->prev)
4314+
len++;
4315+
int8_t *occurs = (int8_t *)alloca(len);
4316+
for (v = troot, i = 0; v != NULL; v = v->prev, i++)
4317+
occurs[i] = v->occurs_inv | (v->occurs_cov << 2);
4318+
check_diagonal(((jl_uniontype_t *)t)->a, troot, param);
4319+
for (v = troot, i = 0; v != NULL; v = v->prev, i++) {
4320+
int8_t occurs_inv = occurs[i] & 3;
4321+
int8_t occurs_cov = occurs[i] >> 2;
4322+
occurs[i] = v->occurs_inv | (v->occurs_cov << 2);
4323+
v->occurs_inv = occurs_inv;
4324+
v->occurs_cov = occurs_cov;
4325+
}
4326+
check_diagonal(((jl_uniontype_t *)t)->b, troot, param);
4327+
for (v = troot, i = 0; v != NULL; v = v->prev, i++) {
4328+
if (v->occurs_inv < (occurs[i] & 3))
4329+
v->occurs_inv = occurs[i] & 3;
4330+
if (v->occurs_cov < (occurs[i] >> 2))
4331+
v->occurs_cov = occurs[i] >> 2;
4332+
}
4333+
}
4334+
else if (jl_is_unionall(t)) {
4335+
assert(troot != NULL);
4336+
jl_varbinding_t *v1 = troot, *v2 = troot->prev;
4337+
while (v2 != NULL) {
4338+
if (v2->var == ((jl_unionall_t *)t)->var) {
4339+
v1->prev = v2->prev;
4340+
break;
4341+
}
4342+
v1 = v2;
4343+
v2 = v2->prev;
4344+
}
4345+
check_diagonal(((jl_unionall_t *)t)->body, troot, param);
4346+
v1->prev = v2;
4347+
}
4348+
else if (jl_is_datatype(t)) {
4349+
int nparam = jl_is_tuple_type(t) ? 1 : 2;
4350+
if (nparam < param) nparam = param;
4351+
for (size_t i = 0; i < jl_nparams(t); i++) {
4352+
check_diagonal(jl_tparam(t, i), troot, nparam);
4353+
}
4354+
}
4355+
else if (jl_is_vararg(t)) {
4356+
jl_value_t *T = jl_unwrap_vararg(t);
4357+
jl_value_t *N = jl_unwrap_vararg_num(t);
4358+
int n = (N && jl_is_long(N)) ? jl_unbox_long(N) : 2;
4359+
if (T && n > 0) check_diagonal(T, troot, param);
4360+
if (T && n > 1) check_diagonal(T, troot, param);
4361+
if (N) check_diagonal(N, troot, 2);
4362+
}
4363+
else if (jl_is_typevar(t)) {
4364+
jl_varbinding_t *v = troot;
4365+
for (; v != NULL; v = v->prev) {
4366+
if (v->var == (jl_tvar_t *)t) {
4367+
if (param == 1 && v->occurs_cov < 2) v->occurs_cov++;
4368+
if (param == 2 && v->occurs_inv < 2) v->occurs_inv++;
4369+
break;
4370+
}
4371+
}
4372+
if (v == NULL)
4373+
check_diagonal(((jl_tvar_t *)t)->ub, troot, 0);
4374+
}
4375+
}
4376+
4377+
static jl_value_t *insert_nondiagonal(jl_value_t *type, jl_varbinding_t *troot, int widen2ub)
4378+
{
4379+
// we must replace each covariant occurrence of newvar with a different newvar2<:newvar (diagonal rule)
4380+
if (jl_is_typevar(type)) {
4381+
jl_varbinding_t *v = troot;
4382+
for (; v != NULL; v = v->prev) {
4383+
if (v->concrete && v->var == (jl_tvar_t *)type)
4384+
break;
4385+
}
4386+
if (v != NULL) {
4387+
if (widen2ub) {
4388+
type = ((jl_tvar_t *)type)->ub;
4389+
}
4390+
else {
4391+
if (v->innervars == NULL)
4392+
v->innervars = jl_alloc_array_1d(jl_array_any_type, 0);
4393+
jl_value_t *newvar = NULL, *lb = v->var->lb, *ub = (jl_value_t *)v->var;
4394+
jl_array_t *innervars = v->innervars;
4395+
JL_GC_PUSH4(&newvar, &lb, &ub, &innervars);
4396+
newvar = (jl_value_t *)jl_new_typevar(v->var->name, lb, ub);
4397+
jl_array_ptr_1d_push(innervars, newvar);
4398+
JL_GC_POP();
4399+
type = newvar;
4400+
}
4401+
}
4402+
}
4403+
else if (jl_is_unionall(type)) {
4404+
jl_value_t *body = ((jl_unionall_t*)type)->body;
4405+
jl_tvar_t *var = ((jl_unionall_t*)type)->var;
4406+
jl_varbinding_t *v = troot;
4407+
for (; v != NULL; v = v->prev) {
4408+
if (v->var == (jl_tvar_t *)var)
4409+
break;
4410+
}
4411+
if (v == NULL) {
4412+
jl_value_t *newbody = insert_nondiagonal(body, troot, widen2ub);
4413+
jl_value_t *newvar = NULL;
4414+
JL_GC_PUSH2(&newbody, &newvar);
4415+
if (body == newbody || jl_has_typevar(newbody, var)) {
4416+
if (body != newbody)
4417+
newbody = jl_new_struct(jl_unionall_type, var, newbody);
4418+
// n.b. we do not widen lb, since that would be the wrong direction
4419+
newvar = insert_nondiagonal(var->ub, troot, widen2ub);
4420+
if (newvar != var->ub) {
4421+
newvar = (jl_value_t*)jl_new_typevar(var->name, var->lb, newvar);
4422+
newbody = jl_apply_type1(newbody, newvar);
4423+
newbody = jl_type_unionall((jl_tvar_t*)newvar, newbody);
4424+
}
4425+
}
4426+
type = newbody;
4427+
JL_GC_POP();
4428+
}
4429+
}
4430+
else if (jl_is_uniontype(type)) {
4431+
jl_value_t *a = ((jl_uniontype_t*)type)->a;
4432+
jl_value_t *b = ((jl_uniontype_t*)type)->b;
4433+
jl_value_t *newa = NULL;
4434+
jl_value_t *newb = NULL;
4435+
JL_GC_PUSH2(&newa, &newb);
4436+
newa = insert_nondiagonal(a, troot, widen2ub);
4437+
newb = insert_nondiagonal(b, troot, widen2ub);
4438+
if (newa != a || newb != b)
4439+
type = jl_new_struct(jl_uniontype_type, newa, newb);
4440+
JL_GC_POP();
4441+
}
4442+
else if (jl_is_vararg(type)) {
4443+
// As for Vararg we'd better widen it's var to ub as otherwise they are still diagonal
4444+
jl_value_t *t = jl_unwrap_vararg(type);
4445+
jl_value_t *n = jl_unwrap_vararg_num(type);
4446+
widen2ub = !(n && jl_is_long(n)) || jl_unbox_long(n) > 1;
4447+
jl_value_t *newt;
4448+
JL_GC_PUSH1(&newt);
4449+
newt = insert_nondiagonal(t, troot, widen2ub);
4450+
if (t != newt)
4451+
type = jl_new_struct(jl_vararg_type, newt, n);
4452+
JL_GC_POP();
4453+
}
4454+
else if (jl_is_datatype(type)) {
4455+
if (jl_is_tuple_type(type)) {
4456+
jl_svec_t *newparams = NULL;
4457+
JL_GC_PUSH1(&newparams);
4458+
for (size_t i = 0; i < jl_nparams(type); i++) {
4459+
jl_value_t *elt = jl_tparam(type, i);
4460+
jl_value_t *newelt = insert_nondiagonal(elt, troot, widen2ub);
4461+
if (elt != newelt) {
4462+
if (!newparams) {
4463+
newparams = (jl_svec_t*)newelt; // temporary root
4464+
newparams = jl_svec_copy(((jl_datatype_t*)type)->parameters);
4465+
}
4466+
jl_svecset(newparams, i, newelt);
4467+
}
4468+
}
4469+
if (newparams)
4470+
type = (jl_value_t*)jl_apply_tuple_type(newparams, 1);
4471+
JL_GC_POP();
4472+
}
4473+
}
4474+
return type;
4475+
}
4476+
4477+
static jl_value_t *_widen_diagonal(jl_value_t *t, jl_varbinding_t *troot) {
4478+
check_diagonal(t, troot, 0);
4479+
int any_concrete = 0;
4480+
for (jl_varbinding_t *v = troot; v != NULL; v = v->prev) {
4481+
v->concrete = v->occurs_cov > 1 && v->occurs_inv == 0;
4482+
any_concrete |= v->concrete;
4483+
}
4484+
if (!any_concrete)
4485+
return t; // no diagonal
4486+
return insert_nondiagonal(t, troot, 0);
4487+
}
4488+
4489+
static jl_value_t *widen_diagonal(jl_value_t *t, jl_unionall_t *u, jl_varbinding_t *troot)
4490+
{
4491+
jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot };
4492+
jl_value_t *nt;
4493+
JL_GC_PUSH2(&vb.innervars, &nt);
4494+
if (jl_is_unionall(u->body))
4495+
nt = widen_diagonal(t, (jl_unionall_t *)u->body, &vb);
4496+
else
4497+
nt = _widen_diagonal(t, &vb);
4498+
if (vb.innervars != NULL) {
4499+
for (size_t i = 0; i < jl_array_nrows(vb.innervars); i++) {
4500+
jl_tvar_t *var = (jl_tvar_t*)jl_array_ptr_ref(vb.innervars, i);
4501+
nt = jl_type_unionall(var, nt);
4502+
}
4503+
}
4504+
JL_GC_POP();
4505+
return nt;
4506+
}
4507+
4508+
JL_DLLEXPORT jl_value_t *jl_widen_diagonal(jl_value_t *t, jl_unionall_t *ua)
4509+
{
4510+
return widen_diagonal(t, ua, NULL);
4511+
}
43074512

43084513
// specificity comparison
43094514

test/compiler/inference.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5558,3 +5558,16 @@ function foo_typed_throw_metherr()
55585558
return 1
55595559
end
55605560
@test Base.return_types(foo_typed_throw_metherr) |> only === Float64
5561+
5562+
# Issue #52168
5563+
f52168(x, t::Type) = x::NTuple{2, Base.inferencebarrier(t)::Type}
5564+
@test f52168((1, 2.), Any) === (1, 2.)
5565+
5566+
# Issue #27031
5567+
let x = 1, _Any = Any
5568+
@noinline bar27031(tt::Tuple{T,T}, ::Type{Val{T}}) where {T} = notsame27031(tt)
5569+
@noinline notsame27031(tt::Tuple{T, T}) where {T} = error()
5570+
@noinline notsame27031(tt::Tuple{T, S}) where {T, S} = "OK"
5571+
foo27031() = bar27031((x, 1.0), Val{_Any})
5572+
@test foo27031() == "OK"
5573+
end

test/core.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8059,3 +8059,11 @@ check_globalref_lowering() = @insert_global
80598059
let src = code_lowered(check_globalref_lowering)[1]
80608060
@test length(src.code) == 2
80618061
end
8062+
8063+
# Test correctness of widen_diagonal
8064+
let widen_diagonal(x::UnionAll) = Base.rewrap_unionall(Base.widen_diagonal(Base.unwrap_unionall(x), x), x),
8065+
check_widen_diagonal(x, y) = !<:(x, y) && x <: widen_diagonal(y)
8066+
@test Tuple{Int,Float64} <: widen_diagonal(NTuple)
8067+
@test Tuple{Int,Float64} <: widen_diagonal(Tuple{T,T} where {T})
8068+
@test Union{Tuple{T}, Tuple{T,Int}} where {T} == widen_diagonal(Union{Tuple{T}, Tuple{T,Int}} where {T})
8069+
end

0 commit comments

Comments
 (0)