Skip to content

Commit 4f8a7b9

Browse files
authored
fix #29269, type intersection bug in union parameters with typevars (#29406)
also fixes #25752
1 parent 6778bef commit 4f8a7b9

File tree

7 files changed

+75
-40
lines changed

7 files changed

+75
-40
lines changed

base/missing.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ nonmissingtype(::Type{Any}) = Any
2727
for U in (:Nothing, :Missing)
2828
@eval begin
2929
promote_rule(::Type{$U}, ::Type{T}) where {T} = Union{T, $U}
30+
promote_rule(::Type{Union{S,$U}}, ::Type{Any}) where {S} = Any
3031
promote_rule(::Type{Union{S,$U}}, ::Type{T}) where {T,S} = Union{promote_type(T, S), $U}
3132
promote_rule(::Type{Any}, ::Type{$U}) = Any
3233
promote_rule(::Type{$U}, ::Type{Any}) = Any
@@ -37,13 +38,16 @@ end
3738
promote_rule(::Type{Union{Nothing, Missing}}, ::Type{Any}) = Any
3839
promote_rule(::Type{Union{Nothing, Missing}}, ::Type{T}) where {T} =
3940
Union{Nothing, Missing, T}
41+
promote_rule(::Type{Union{Nothing, Missing, S}}, ::Type{Any}) where {S} = Any
4042
promote_rule(::Type{Union{Nothing, Missing, S}}, ::Type{T}) where {T,S} =
4143
Union{Nothing, Missing, promote_type(T, S)}
4244

45+
convert(::Type{Union{T, Missing}}, x::Union{T, Missing}) where {T} = x
4346
convert(::Type{Union{T, Missing}}, x) where {T} = convert(T, x)
4447
# To fix ambiguities
4548
convert(::Type{Missing}, ::Missing) = missing
4649
convert(::Type{Union{Nothing, Missing}}, x::Union{Nothing, Missing}) = x
50+
convert(::Type{Union{Nothing, Missing, T}}, x::Union{Nothing, Missing, T}) where {T} = x
4751
convert(::Type{Union{Nothing, Missing}}, x) =
4852
throw(MethodError(convert, (Union{Nothing, Missing}, x)))
4953
# To print more appropriate message than "T not defined"

base/some.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ promote_rule(::Type{Some{T}}, ::Type{Nothing}) where {T} = Union{Some{T}, Nothin
1818
convert(::Type{Some{T}}, x::Some) where {T} = Some{T}(convert(T, x.value))
1919
convert(::Type{Union{Some{T}, Nothing}}, x::Some) where {T} = convert(Some{T}, x)
2020

21+
convert(::Type{Union{T, Nothing}}, x::Union{T, Nothing}) where {T} = x
2122
convert(::Type{Union{T, Nothing}}, x::Any) where {T} = convert(T, x)
22-
convert(::Type{Nothing}, x::Any) = throw(MethodError(convert, (Nothing, x)))
2323
convert(::Type{Nothing}, x::Nothing) = nothing
24+
convert(::Type{Nothing}, x::Any) = throw(MethodError(convert, (Nothing, x)))
2425

2526
function show(io::IO, x::Some)
2627
if get(io, :typeinfo, Any) == typeof(x)

src/subtype.c

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ static int var_outside(jl_stenv_t *e, jl_tvar_t *x, jl_tvar_t *y)
521521
return 0;
522522
}
523523

524-
static jl_value_t *intersect_ufirst(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth);
524+
static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth);
525525

526526
// check that type var `b` is <: `a`, and update b's upper bound.
527527
static int var_lt(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int param)
@@ -539,7 +539,7 @@ static int var_lt(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int param)
539539
// for this to work we need to compute issub(left,right) before issub(right,left),
540540
// since otherwise the issub(a, bb.ub) check in var_gt becomes vacuous.
541541
if (e->intersection) {
542-
jl_value_t *ub = intersect_ufirst(bb->ub, a, e, bb->depth0);
542+
jl_value_t *ub = intersect_aside(bb->ub, a, e, bb->depth0);
543543
if (ub != (jl_value_t*)b)
544544
bb->ub = ub;
545545
}
@@ -1328,16 +1328,32 @@ JL_DLLEXPORT int jl_isa(jl_value_t *x, jl_value_t *t)
13281328

13291329
static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);
13301330

1331+
static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e);
1332+
1333+
// intersect in nested union environment, similar to subtype_ccheck
1334+
static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth)
1335+
{
1336+
jl_value_t *res;
1337+
int savedepth = e->invdepth;
1338+
jl_unionstate_t oldRunions = e->Runions;
1339+
e->invdepth = depth;
1340+
1341+
res = intersect_all(x, y, e);
1342+
1343+
e->Runions = oldRunions;
1344+
e->invdepth = savedepth;
1345+
return res;
1346+
}
1347+
13311348
static jl_value_t *intersect_union(jl_value_t *x, jl_uniontype_t *u, jl_stenv_t *e, int8_t R, int param)
13321349
{
13331350
if (param == 2 || (!jl_has_free_typevars(x) && !jl_has_free_typevars((jl_value_t*)u))) {
1334-
jl_value_t *a=NULL, *b=NULL, *save=NULL; jl_savedenv_t se;
1335-
JL_GC_PUSH3(&a, &b, &save);
1336-
save_env(e, &save, &se);
1337-
a = R ? intersect(x, u->a, e, param) : intersect(u->a, x, e, param);
1338-
restore_env(e, NULL, &se);
1339-
b = R ? intersect(x, u->b, e, param) : intersect(u->b, x, e, param);
1340-
free(se.buf);
1351+
jl_value_t *a=NULL, *b=NULL;
1352+
JL_GC_PUSH2(&a, &b);
1353+
jl_unionstate_t oldRunions = e->Runions;
1354+
a = R ? intersect_all(x, u->a, e) : intersect_all(u->a, x, e);
1355+
b = R ? intersect_all(x, u->b, e) : intersect_all(u->b, x, e);
1356+
e->Runions = oldRunions;
13411357
jl_value_t *i = simple_join(a,b);
13421358
JL_GC_POP();
13431359
return i;
@@ -1347,21 +1363,6 @@ static jl_value_t *intersect_union(jl_value_t *x, jl_uniontype_t *u, jl_stenv_t
13471363
return R ? intersect(x, choice, e, param) : intersect(choice, x, e, param);
13481364
}
13491365

1350-
static jl_value_t *intersect_ufirst(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth)
1351-
{
1352-
jl_value_t *res;
1353-
int savedepth = e->invdepth;
1354-
e->invdepth = depth;
1355-
if (jl_is_uniontype(x) && jl_is_typevar(y))
1356-
res = intersect_union(y, (jl_uniontype_t*)x, e, 0, 0);
1357-
else if (jl_is_typevar(x) && jl_is_uniontype(y))
1358-
res = intersect_union(x, (jl_uniontype_t*)y, e, 1, 0);
1359-
else
1360-
res = intersect(x, y, e, 0);
1361-
e->invdepth = savedepth;
1362-
return res;
1363-
}
1364-
13651366
// set a variable to a non-type constant
13661367
static jl_value_t *set_var_to_const(jl_varbinding_t *bb, jl_value_t *v JL_MAYBE_UNROOTED, jl_varbinding_t *othervar)
13671368
{
@@ -1386,13 +1387,11 @@ static jl_value_t *set_var_to_const(jl_varbinding_t *bb, jl_value_t *v JL_MAYBE_
13861387

13871388
static int try_subtype_in_env(jl_value_t *a, jl_value_t *b, jl_stenv_t *e)
13881389
{
1389-
jl_value_t *root=NULL; jl_savedenv_t se; int ret=0;
1390+
jl_value_t *root=NULL; jl_savedenv_t se;
13901391
JL_GC_PUSH1(&root);
13911392
save_env(e, &root, &se);
1392-
if (subtype_in_env(a, b, e))
1393-
ret = 1;
1394-
else
1395-
restore_env(e, root, &se);
1393+
int ret = subtype_in_env(a, b, e);
1394+
restore_env(e, root, &se);
13961395
free(se.buf);
13971396
JL_GC_POP();
13981397
return ret;
@@ -1402,15 +1401,15 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
14021401
{
14031402
jl_varbinding_t *bb = lookup(e, b);
14041403
if (bb == NULL)
1405-
return R ? intersect_ufirst(a, b->ub, e, 0) : intersect_ufirst(b->ub, a, e, 0);
1404+
return R ? intersect_aside(a, b->ub, e, 0) : intersect_aside(b->ub, a, e, 0);
14061405
if (bb->lb == bb->ub && jl_is_typevar(bb->lb))
14071406
return intersect(a, bb->lb, e, param);
14081407
if (!jl_is_type(a) && !jl_is_typevar(a))
14091408
return set_var_to_const(bb, a, NULL);
14101409
int d = bb->depth0;
14111410
jl_value_t *root=NULL; jl_savedenv_t se;
14121411
if (param == 2) {
1413-
jl_value_t *ub = R ? intersect_ufirst(a, bb->ub, e, d) : intersect_ufirst(bb->ub, a, e, d);
1412+
jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, d) : intersect_aside(bb->ub, a, e, d);
14141413
JL_GC_PUSH2(&ub, &root);
14151414
if (!jl_has_free_typevars(ub) && !jl_has_free_typevars(bb->lb)) {
14161415
save_env(e, &root, &se);
@@ -1450,10 +1449,10 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
14501449
if (try_subtype_in_env(bb->ub, a, e))
14511450
return (jl_value_t*)b;
14521451
}
1453-
return R ? intersect_ufirst(a, bb->ub, e, d) : intersect_ufirst(bb->ub, a, e, d);
1452+
return R ? intersect_aside(a, bb->ub, e, d) : intersect_aside(bb->ub, a, e, d);
14541453
}
14551454
else if (bb->concrete || bb->constraintkind == 1) {
1456-
jl_value_t *ub = R ? intersect_ufirst(a, bb->ub, e, d) : intersect_ufirst(bb->ub, a, e, d);
1455+
jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, d) : intersect_aside(bb->ub, a, e, d);
14571456
JL_GC_PUSH1(&ub);
14581457
if (ub == jl_bottom_type || !subtype_in_env(bb->lb, a, e)) {
14591458
JL_GC_POP();
@@ -1473,7 +1472,7 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
14731472
return a;
14741473
}
14751474
assert(bb->constraintkind == 3);
1476-
jl_value_t *ub = R ? intersect_ufirst(a, bb->ub, e, d) : intersect_ufirst(bb->ub, a, e, d);
1475+
jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, d) : intersect_aside(bb->ub, a, e, d);
14771476
if (ub == jl_bottom_type)
14781477
return jl_bottom_type;
14791478
if (jl_is_typevar(a))
@@ -1494,7 +1493,7 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
14941493
root = NULL;
14951494
JL_GC_PUSH2(&root, &ub);
14961495
save_env(e, &root, &se);
1497-
jl_value_t *ii = R ? intersect_ufirst(a, bb->lb, e, d) : intersect_ufirst(bb->lb, a, e, d);
1496+
jl_value_t *ii = R ? intersect_aside(a, bb->lb, e, d) : intersect_aside(bb->lb, a, e, d);
14981497
if (ii == jl_bottom_type) {
14991498
restore_env(e, root, &se);
15001499
ii = (jl_value_t*)b;
@@ -2050,7 +2049,7 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa
20502049
return jl_bottom_type;
20512050
jl_value_t *ub=NULL, *lb=NULL;
20522051
JL_GC_PUSH2(&lb, &ub);
2053-
ub = intersect_ufirst(xub, yub, e, xx ? xx->depth0 : 0);
2052+
ub = intersect_aside(xub, yub, e, xx ? xx->depth0 : 0);
20542053
lb = simple_join(xlb, ylb);
20552054
if (yy) {
20562055
if (lb != y)

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,9 @@ mul!(out::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, in::StridedMatrix) = o
295295
*(transD::Transpose{<:Any,<:Diagonal}, transA::Transpose{<:Any,<:RealHermSymComplexSym}) = transD * transA.parent
296296
*(adjA::Adjoint{<:Any,<:RealHermSymComplexHerm}, adjD::Adjoint{<:Any,<:Diagonal}) = adjA.parent * adjD
297297
*(adjD::Adjoint{<:Any,<:Diagonal}, adjA::Adjoint{<:Any,<:RealHermSymComplexHerm}) = adjD * adjA.parent
298+
mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, B::Adjoint{<:Any,<:RealHermSym}) = mul!(C, A, B.parent)
298299
mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, B::Adjoint{<:Any,<:RealHermSymComplexHerm}) = mul!(C, A, B.parent)
300+
mul!(C::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, B::Transpose{<:Any,<:RealHermSym}) = mul!(C, A, B.parent)
299301
mul!(C::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, B::Transpose{<:Any,<:RealHermSymComplexSym}) = mul!(C, A, B.parent)
300302
mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, B::Adjoint{<:Any,<:RealHermSymComplexSym}) = C .= adjoint.(A.parent.diag) .* B
301303
mul!(C::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, B::Transpose{<:Any,<:RealHermSymComplexHerm}) = C .= transpose.(A.parent.diag) .* B

stdlib/LinearAlgebra/src/symmetric.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ end
176176
convert(T::Type{<:Symmetric}, m::Union{Symmetric,Hermitian}) = m isa T ? m : T(m)
177177
convert(T::Type{<:Hermitian}, m::Union{Symmetric,Hermitian}) = m isa T ? m : T(m)
178178

179-
const HermOrSym{T,S} = Union{Hermitian{T,S}, Symmetric{T,S}}
179+
const HermOrSym{T, S} = Union{Hermitian{T,S}, Symmetric{T,S}}
180+
const RealHermSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}}
180181
const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S}}
181182
const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}}
182183

@@ -427,11 +428,17 @@ mul!(C::StridedMatrix{T}, A::StridedMatrix{T}, B::Hermitian{T,<:StridedMatrix})
427428
*(A::AbstractMatrix, adjB::Adjoint{<:Any,<:RealHermSymComplexHerm}) = A * adjB.parent
428429

429430
# ambiguities with transposed AbstractMatrix methods in linalg/matmul.jl
431+
*(transA::Transpose{<:Any,<:RealHermSym}, transB::Transpose{<:Any,<:RealHermSym}) = transA * transB.parent
432+
*(transA::Transpose{<:Any,<:RealHermSym}, transB::Transpose{<:Any,<:RealHermSymComplexSym}) = transA * transB.parent
430433
*(transA::Transpose{<:Any,<:RealHermSymComplexSym}, transB::Transpose{<:Any,<:RealHermSymComplexSym}) = transA.parent * transB.parent
434+
*(transA::Transpose{<:Any,<:RealHermSymComplexSym}, transB::Transpose{<:Any,<:RealHermSym}) = transA.parent * transB
431435
*(transA::Transpose{<:Any,<:RealHermSymComplexSym}, transB::Transpose{<:Any,<:RealHermSymComplexHerm}) = transA.parent * transB
432436
*(transA::Transpose{<:Any,<:RealHermSymComplexHerm}, transB::Transpose{<:Any,<:RealHermSymComplexSym}) = transA * transB.parent
437+
*(adjA::Adjoint{<:Any,<:RealHermSym}, adjB::Adjoint{<:Any,<:RealHermSym}) = adjA * adjB.parent
433438
*(adjA::Adjoint{<:Any,<:RealHermSymComplexHerm}, adjB::Adjoint{<:Any,<:RealHermSymComplexHerm}) = adjA.parent * adjB.parent
439+
*(adjA::Adjoint{<:Any,<:RealHermSym}, adjB::Adjoint{<:Any,<:RealHermSymComplexHerm}) = adjA * adjB.parent
434440
*(adjA::Adjoint{<:Any,<:RealHermSymComplexSym}, adjB::Adjoint{<:Any,<:RealHermSymComplexHerm}) = adjA * adjB.parent
441+
*(adjA::Adjoint{<:Any,<:RealHermSymComplexHerm}, adjB::Adjoint{<:Any,<:RealHermSym}) = adjA.parent * adjB
435442
*(adjA::Adjoint{<:Any,<:RealHermSymComplexHerm}, adjB::Adjoint{<:Any,<:RealHermSymComplexSym}) = adjA.parent * adjB
436443

437444
# ambiguities with AbstractTriangular

test/ambiguous.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ end
275275
pop!(need_to_handle_undef_sparam, which(Core.Compiler.convert, (Type{Union{T, Nothing}} where T, Core.Compiler.Some)))
276276
pop!(need_to_handle_undef_sparam, which(Core.Compiler.convert, Tuple{Type{Tuple{Vararg{Int}}}, Tuple{}}))
277277
pop!(need_to_handle_undef_sparam, which(Core.Compiler.convert, Tuple{Type{Tuple{Vararg{Int}}}, Tuple{Int8}}))
278+
pop!(need_to_handle_undef_sparam, which(Core.Compiler.convert, Tuple{Type{Union{Nothing,T}},Union{Nothing,T}} where T))
278279
@test need_to_handle_undef_sparam == Set()
279280
end
280281
let need_to_handle_undef_sparam =
@@ -299,6 +300,12 @@ end
299300
pop!(need_to_handle_undef_sparam, which(Base.convert, (Type{Union{T, Nothing}} where T, Some)))
300301
pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Tuple{Vararg{Int}}}, Tuple{}}))
301302
pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Tuple{Vararg{Int}}}, Tuple{Int8}}))
303+
pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Union{Nothing,T}},Union{Nothing,T}} where T))
304+
pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Union{Missing,T}},Union{Missing,T}} where T))
305+
pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Union{Missing,Nothing,T}},Union{Missing,Nothing,T}} where T))
306+
pop!(need_to_handle_undef_sparam, which(Base.promote_rule, Tuple{Type{Union{Nothing,T}},Type{Any}} where T))
307+
pop!(need_to_handle_undef_sparam, which(Base.promote_rule, Tuple{Type{Union{Missing,T}},Type{Any}} where T))
308+
pop!(need_to_handle_undef_sparam, which(Base.promote_rule, Tuple{Type{Union{Missing,Nothing,T}},Type{Any}} where T))
302309
@test need_to_handle_undef_sparam == Set()
303310
end
304311
end

test/subtype.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -847,10 +847,14 @@ function test_intersection()
847847
@testintersect(Ref{@UnionAll T @UnionAll S Tuple{T,S}},
848848
Ref{@UnionAll T Tuple{T,T}}, Bottom)
849849

850+
# both of these answers seem acceptable
851+
#@testintersect(Tuple{T,T} where T<:Union{UpperTriangular, UnitUpperTriangular},
852+
# Tuple{AbstractArray{T,N}, AbstractArray{T,N}} where N where T,
853+
# Union{Tuple{T,T} where T<:UpperTriangular,
854+
# Tuple{T,T} where T<:UnitUpperTriangular})
850855
@testintersect(Tuple{T,T} where T<:Union{UpperTriangular, UnitUpperTriangular},
851856
Tuple{AbstractArray{T,N}, AbstractArray{T,N}} where N where T,
852-
Union{Tuple{T,T} where T<:UpperTriangular,
853-
Tuple{T,T} where T<:UnitUpperTriangular})
857+
Tuple{T,T} where T<:Union{UpperTriangular, UnitUpperTriangular})
854858

855859
@testintersect(DataType, Type, DataType)
856860
@testintersect(DataType, Type{T} where T<:Integer, Type{T} where T<:Integer)
@@ -1358,6 +1362,17 @@ end
13581362
Tuple{Val{2}, Vararg{Val{3}}},
13591363
Union{})
13601364

1365+
# issue #25752
1366+
@testintersect(Base.RefValue, Ref{Union{Int,T}} where T,
1367+
Base.RefValue{Union{Int,T}} where T)
1368+
# issue #29269
1369+
@testintersect((Tuple{Int, Array{T}} where T),
1370+
(Tuple{Any, Vector{Union{Missing,T}}} where T),
1371+
(Tuple{Int, Vector{Union{Missing,T}}} where T))
1372+
@testintersect((Tuple{Int, Array{T}} where T),
1373+
(Tuple{Any, Vector{Union{Missing,Nothing,T}}} where T),
1374+
(Tuple{Int, Vector{Union{Missing,Nothing,T}}} where T))
1375+
13611376
# issue #29955
13621377
struct M29955{T, TV<:AbstractVector{T}}
13631378
end

0 commit comments

Comments
 (0)