Skip to content

Commit 4dd3346

Browse files
authored
backport #41363: fix equality of QRCompactWY (#41395)
* fix equality of QRCompactWY (#41363) Equality for `QRCompactWY` did not ignore the subdiagonal entries of `T` leading to nondeterministic behavior. This is pulled out from #41228, since this change should be less controversial than the other changes there and this particular bug just came up in ChainRules again.
1 parent c5eceef commit 4dd3346

File tree

4 files changed

+102
-1
lines changed

4 files changed

+102
-1
lines changed

stdlib/LinearAlgebra/src/LinearAlgebra.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as
1616
setindex!, show, similar, sin, sincos, sinh, size, sqrt,
1717
strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec
1818
using Base: hvcat_fill, IndexLinear, promote_op, promote_typeof,
19-
@propagate_inbounds, @pure, reduce, typed_vcat, require_one_based_indexing
19+
@propagate_inbounds, @pure, reduce, typed_vcat, require_one_based_indexing,
20+
splat
2021
using Base.Broadcast: Broadcasted, broadcasted
2122

2223
export

stdlib/LinearAlgebra/src/qr.jl

+34
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,40 @@ Base.iterate(S::QRCompactWY) = (S.Q, Val(:R))
127127
Base.iterate(S::QRCompactWY, ::Val{:R}) = (S.R, Val(:done))
128128
Base.iterate(S::QRCompactWY, ::Val{:done}) = nothing
129129

130+
# returns upper triangular views of all non-undef values of `qr(A).T`:
131+
#
132+
# julia> sparse(qr(A).T .== qr(A).T)
133+
# 36×100 SparseMatrixCSC{Bool, Int64} with 1767 stored entries:
134+
# ⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿
135+
# ⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿
136+
# ⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿
137+
# ⠀⠀⠀⠀⠀⠂⠛⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿
138+
# ⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⢀⠐⠙⢿⣿⣿⣿⣿
139+
# ⠀⠀⠐⠀⠀⠀⠀⠀⠀⢀⢙⣿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠁⠀⡀⠀⠙⢿⣿⣿
140+
# ⠀⠀⠐⠀⠀⠀⠀⠀⠀⠀⠄⠀⠙⢿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⣿⣿⠀⠀⠀⠀⠀⠀⡀⠀⠀⢀⠀⠀⠙⢿
141+
# ⠀⡀⠀⠀⠀⠀⠀⠀⠂⠒⠒⠀⠀⠀⠙⢿⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣿⣿⠀⠀⠀⠀⠀⠀⠀⢀⠀⠀⠀⡀⠀⠀
142+
# ⠀⠀⠀⠀⠀⠀⠀⠀⣈⡀⠀⠀⠀⠀⠀⠀⠙⢿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⠀⠀⠀⠀⠀⠀⠀⠀⠀⡀⠂⠀⢀⠀
143+
#
144+
function _triuppers_qr(T)
145+
blocksize, cols = size(T)
146+
return Iterators.map(0:div(cols - 1, blocksize)) do i
147+
n = min(blocksize, cols - i * blocksize)
148+
return UpperTriangular(view(T, 1:n, (1:n) .+ i * blocksize))
149+
end
150+
end
151+
152+
function Base.hash(F::QRCompactWY, h::UInt)
153+
return hash(F.factors, foldr(hash, _triuppers_qr(F.T); init=hash(QRCompactWY, h)))
154+
end
155+
function Base.:(==)(A::QRCompactWY, B::QRCompactWY)
156+
return A.factors == B.factors && all(splat(==), zip(_triuppers_qr.((A.T, B.T))...))
157+
end
158+
function Base.isequal(A::QRCompactWY, B::QRCompactWY)
159+
return isequal(A.factors, B.factors) && all(zip(_triuppers_qr.((A.T, B.T))...)) do (a, b)
160+
isequal(a, b)::Bool
161+
end
162+
end
163+
130164
"""
131165
QRPivoted <: Factorization
132166
+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
module TestFactorization
4+
using Test, LinearAlgebra
5+
6+
@testset "equality for factorizations - $f" for f in Any[
7+
bunchkaufman,
8+
cholesky,
9+
x -> cholesky(x, Val(true)),
10+
eigen,
11+
hessenberg,
12+
lq,
13+
lu,
14+
qr,
15+
x -> qr(x, Val(true)),
16+
svd,
17+
schur,
18+
]
19+
A = randn(3, 3)
20+
A = A * A' # ensure A is pos. def. and symmetric
21+
F, G = f(A), f(A)
22+
23+
@test F == G
24+
@test isequal(F, G)
25+
@test hash(F) == hash(G)
26+
27+
f === hessenberg && continue
28+
29+
# change all arrays in F to have eltype Float32
30+
F = typeof(F).name.wrapper(Base.mapany(1:nfields(F)) do i
31+
x = getfield(F, i)
32+
return x isa AbstractArray{Float64} ? Float32.(x) : x
33+
end...)
34+
# round all arrays in G to the nearest Float64 representable as Float32
35+
G = typeof(G).name.wrapper(Base.mapany(1:nfields(G)) do i
36+
x = getfield(G, i)
37+
return x isa AbstractArray{Float64} ? Float64.(Float32.(x)) : x
38+
end...)
39+
40+
if f === qr
41+
@test F == G
42+
@test isequal(F, G)
43+
else
44+
@test_broken F == G
45+
@test_broken isequal(F, G)
46+
end
47+
@test hash(F) == hash(G)
48+
end
49+
50+
@testset "equality of QRCompactWY" begin
51+
A = rand(100, 100)
52+
F, G = qr(A), qr(A)
53+
54+
@test F == G
55+
@test isequal(F, G)
56+
@test hash(F) == hash(G)
57+
58+
G.T[28, 100] = 42
59+
60+
@test F != G
61+
@test !isequal(F, G)
62+
@test hash(F) != hash(G)
63+
end
64+
65+
end

stdlib/LinearAlgebra/test/testgroups

+1
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@ givens
2525
structuredbroadcast
2626
addmul
2727
ldlt
28+
factorization

0 commit comments

Comments
 (0)