Skip to content

Commit 2a2068d

Browse files
committed
Subtype: avoid false alarm caused by eager forall_exists_subtype. (#48441)
* Avoid earsing `Runion` within nested `forall_exists_subtype` If `Runion.more != 0` we‘d better not erase the local `Runion` as we need it if the subtyping fails after. This commit replaces `forall_exists_subtype` with a local version. It first tries `forall_exists_subtype` and estimates the "problem scale". If subtyping fails and the scale looks small then it switches to the slow path. TODO: At present, the "problem scale" only counts the number of checked `Lunion`s. But perhaps we need a more accurate result (e.g. sum of `Runion.depth`) * Change the reversed subtyping into a local check. Make sure we don't forget the bound in `env`. (And we can fuse `local_forall_exists_subtype`) * Optimization for non-union invariant parameter.
1 parent 2117018 commit 2a2068d

File tree

2 files changed

+91
-41
lines changed

2 files changed

+91
-41
lines changed

src/subtype.c

Lines changed: 82 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv
555555
return u;
556556
}
557557

558-
static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);
558+
static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int limit_slow);
559559

560560
// subtype for variable bounds consistency check. needs its own forall/exists environment.
561561
static int subtype_ccheck(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
@@ -571,17 +571,7 @@ static int subtype_ccheck(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
571571
if (x == (jl_value_t*)jl_any_type && jl_is_datatype(y))
572572
return 0;
573573
jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
574-
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
575-
int sub;
576-
e->Lunions.used = e->Runions.used = 0;
577-
e->Runions.depth = 0;
578-
e->Runions.more = 0;
579-
e->Lunions.depth = 0;
580-
e->Lunions.more = 0;
581-
582-
sub = forall_exists_subtype(x, y, e, 0);
583-
584-
pop_unionstate(&e->Runions, &oldRunions);
574+
int sub = local_forall_exists_subtype(x, y, e, 0, 1);
585575
pop_unionstate(&e->Lunions, &oldLunions);
586576
return sub;
587577
}
@@ -1362,6 +1352,72 @@ static int is_definite_length_tuple_type(jl_value_t *x)
13621352
return k == JL_VARARG_NONE || k == JL_VARARG_INT;
13631353
}
13641354

1355+
static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int *count, int *noRmore);
1356+
1357+
static int may_contain_union_decision(jl_value_t *x, jl_stenv_t *e, jl_typeenv_t *log) JL_NOTSAFEPOINT
1358+
{
1359+
if (x == NULL || x == (jl_value_t*)jl_any_type || x == jl_bottom_type)
1360+
return 0;
1361+
if (jl_is_unionall(x))
1362+
return may_contain_union_decision(((jl_unionall_t *)x)->body, e, log);
1363+
if (jl_is_datatype(x)) {
1364+
jl_datatype_t *xd = (jl_datatype_t *)x;
1365+
for (int i = 0; i < jl_nparams(xd); i++) {
1366+
jl_value_t *param = jl_tparam(xd, i);
1367+
if (jl_is_vararg(param))
1368+
param = jl_unwrap_vararg(param);
1369+
if (may_contain_union_decision(param, e, log))
1370+
return 1;
1371+
}
1372+
return 0;
1373+
}
1374+
if (!jl_is_typevar(x))
1375+
return 1;
1376+
jl_typeenv_t *t = log;
1377+
while (t != NULL) {
1378+
if (x == (jl_value_t *)t->var)
1379+
return 1;
1380+
t = t->prev;
1381+
}
1382+
jl_typeenv_t newlog = { (jl_tvar_t*)x, NULL, log };
1383+
jl_varbinding_t *xb = lookup(e, (jl_tvar_t *)x);
1384+
return may_contain_union_decision(xb ? xb->lb : ((jl_tvar_t *)x)->lb, e, &newlog) ||
1385+
may_contain_union_decision(xb ? xb->ub : ((jl_tvar_t *)x)->ub, e, &newlog);
1386+
}
1387+
1388+
static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int limit_slow)
1389+
{
1390+
int16_t oldRmore = e->Runions.more;
1391+
int sub;
1392+
if (may_contain_union_decision(y, e, NULL) && pick_union_decision(e, 1) == 0) {
1393+
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
1394+
e->Lunions.used = e->Runions.used = 0;
1395+
e->Lunions.depth = e->Runions.depth = 0;
1396+
e->Lunions.more = e->Runions.more = 0;
1397+
int count = 0, noRmore = 0;
1398+
sub = _forall_exists_subtype(x, y, e, param, &count, &noRmore);
1399+
pop_unionstate(&e->Runions, &oldRunions);
1400+
// we should not try the slow path if `forall_exists_subtype` has tested all cases;
1401+
// Once limit_slow == 1, also skip it if
1402+
// 1) `forall_exists_subtype` return false
1403+
// 2) the left `Union` looks big
1404+
if (noRmore || (limit_slow && (count > 3 || !sub)))
1405+
e->Runions.more = oldRmore;
1406+
}
1407+
else {
1408+
// slow path
1409+
e->Lunions.used = 0;
1410+
while (1) {
1411+
e->Lunions.more = 0;
1412+
e->Lunions.depth = 0;
1413+
sub = subtype(x, y, e, param);
1414+
if (!sub || !next_union_state(e, 0))
1415+
break;
1416+
}
1417+
}
1418+
return sub;
1419+
}
1420+
13651421
static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
13661422
{
13671423
if (obviously_egal(x, y)) return 1;
@@ -1380,33 +1436,13 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
13801436
}
13811437

13821438
jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
1383-
e->Lunions.used = 0;
1384-
int sub;
13851439

1386-
if (!jl_has_free_typevars(x) || !jl_has_free_typevars(y)) {
1387-
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
1388-
e->Runions.used = 0;
1389-
e->Runions.depth = 0;
1390-
e->Runions.more = 0;
1391-
e->Lunions.depth = 0;
1392-
e->Lunions.more = 0;
1393-
1394-
sub = forall_exists_subtype(x, y, e, 2);
1395-
1396-
pop_unionstate(&e->Runions, &oldRunions);
1397-
}
1398-
else {
1399-
while (1) {
1400-
e->Lunions.more = 0;
1401-
e->Lunions.depth = 0;
1402-
sub = subtype(x, y, e, 2);
1403-
if (!sub || !next_union_state(e, 0))
1404-
break;
1405-
}
1406-
}
1440+
int limit_slow = !jl_has_free_typevars(x) || !jl_has_free_typevars(y);
1441+
int sub = local_forall_exists_subtype(x, y, e, 2, limit_slow) &&
1442+
local_forall_exists_subtype(y, x, e, 0, 0);
14071443

14081444
pop_unionstate(&e->Lunions, &oldLunions);
1409-
return sub && subtype(y, x, e, 0);
1445+
return sub;
14101446
}
14111447

14121448
static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_t *saved, jl_savedenv_t *se, int param)
@@ -1433,7 +1469,7 @@ static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_
14331469
}
14341470
}
14351471

1436-
static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
1472+
static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int *count, int *noRmore)
14371473
{
14381474
// The depth recursion has the following shape, after simplification:
14391475
// ∀₁
@@ -1446,8 +1482,12 @@ static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, in
14461482

14471483
e->Lunions.used = 0;
14481484
int sub;
1485+
if (count) *count = 0;
1486+
if (noRmore) *noRmore = 1;
14491487
while (1) {
14501488
sub = exists_subtype(x, y, e, saved, &se, param);
1489+
if (count) *count = (*count < 4) ? *count + 1 : 4;
1490+
if (noRmore) *noRmore = *noRmore && e->Runions.more == 0;
14511491
if (!sub || !next_union_state(e, 0))
14521492
break;
14531493
free_env(&se);
@@ -1459,6 +1499,11 @@ static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, in
14591499
return sub;
14601500
}
14611501

1502+
static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
1503+
{
1504+
return _forall_exists_subtype(x, y, e, param, NULL, NULL);
1505+
}
1506+
14621507
static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
14631508
{
14641509
e->vars = NULL;

test/subtype.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,6 +1406,8 @@ f24521(::Type{T}, ::Type{T}) where {T} = T
14061406
@test !(Ref{Union{Int64, Val{Number}}} <: Ref{Union{Val{T}, T}} where T)
14071407
@test !(Ref{Union{Ref{Number}, Int64}} <: Ref{Union{Ref{T}, T}} where T)
14081408
@test !(Ref{Union{Val{Number}, Int64}} <: Ref{Union{Val{T}, T}} where T)
1409+
@test !(Val{Ref{Union{Int64, Ref{Number}}}} <: Val{S} where {S<:Ref{Union{Ref{T}, T}} where T})
1410+
@test !(Tuple{Ref{Union{Int64, Ref{Number}}}} <: Tuple{S} where {S<:Ref{Union{Ref{T}, T}} where T})
14091411

14101412
# issue #26180
14111413
@test !(Ref{Union{Ref{Int64}, Ref{Number}}} <: Ref{Ref{T}} where T)
@@ -2270,8 +2272,8 @@ abstract type P47654{A} end
22702272
@test_broken typeintersect(Tuple{Vector{VT}, Vector{VT}} where {N1, VT<:AbstractVector{N1}},
22712273
Tuple{Vector{VN} where {N, VN<:AbstractVector{N}}, Vector{Vector{Float64}}}) !== Union{}
22722274
#issue 40865
2273-
@test_broken Tuple{Set{Ref{Int}}, Set{Ref{Int}}} <: Tuple{Set{KV}, Set{K}} where {K,KV<:Union{K,Ref{K}}}
2274-
@test_broken Tuple{Set{Val{Int}}, Set{Val{Int}}} <: Tuple{Set{KV}, Set{K}} where {K,KV<:Union{K,Val{K}}}
2275+
@test Tuple{Set{Ref{Int}}, Set{Ref{Int}}} <: Tuple{Set{KV}, Set{K}} where {K,KV<:Union{K,Ref{K}}}
2276+
@test Tuple{Set{Val{Int}}, Set{Val{Int}}} <: Tuple{Set{KV}, Set{K}} where {K,KV<:Union{K,Val{K}}}
22752277

22762278
#issue 39099
22772279
A = Tuple{Tuple{Int, Int, Vararg{Int, N}}, Tuple{Int, Vararg{Int, N}}, Tuple{Vararg{Int, N}}} where N
@@ -2308,7 +2310,10 @@ end
23082310

23092311
# try to fool a greedy algorithm that picks X=Int, Y=String here
23102312
@test Tuple{Ref{Union{Int,String}}, Ref{Union{Int,String}}} <: Tuple{Ref{Union{X,Y}}, Ref{X}} where {X,Y}
2311-
# this slightly more complex case has been broken since 1.0 (worked in 0.6)
2312-
@test_broken Tuple{Ref{Union{Int,String,Missing}}, Ref{Union{Int,String}}} <: Tuple{Ref{Union{X,Y}}, Ref{X}} where {X,Y}
2313+
@test Tuple{Ref{Union{Int,String,Missing}}, Ref{Union{Int,String}}} <: Tuple{Ref{Union{X,Y}}, Ref{X}} where {X,Y}
23132314

23142315
@test !(Tuple{Any, Any, Any} <: Tuple{Any, Vararg{T}} where T)
2316+
2317+
let a = (isodd(i) ? Pair{Char, String} : Pair{String, String} for i in 1:2000)
2318+
@test Tuple{Type{Pair{Union{Char, String}, String}}, a...} <: Tuple{Type{Pair{K, V}}, Vararg{Pair{A, B} where B where A}} where V where K
2319+
end

0 commit comments

Comments
 (0)