Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 12 additions & 18 deletions ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ using LinearAlgebra

MatrixAlgebraKit.iszerotangent(::AbstractZero) = true

@non_differentiable MatrixAlgebraKit.select_algorithm(args...)
@non_differentiable MatrixAlgebraKit.initialize_output(args...)
@non_differentiable MatrixAlgebraKit.check_input(args...)
@non_differentiable MatrixAlgebraKit.isisometry(args...)
@non_differentiable MatrixAlgebraKit.isunitary(args...)

function ChainRulesCore.rrule(::typeof(copy_input), f, A)
project = ProjectTo(A)
copy_input_pullback(ΔA) = (NoTangent(), NoTangent(), project(unthunk(ΔA)))
Expand All @@ -35,18 +41,12 @@ for qr_f in (:qr_compact, :qr_full)
end
end
end
function ChainRulesCore.rrule(::typeof(qr_null!), A::AbstractMatrix, N, alg)
function ChainRulesCore.rrule(::typeof(qr_null!), A, N, alg)
Ac = copy_input(qr_full, A)
QR = initialize_output(qr_full!, A, alg)
Q, R = qr_full!(Ac, QR, alg)
N = copy!(N, view(Q, 1:size(A, 1), (size(A, 2) + 1):size(A, 1)))
N = qr_null!(Ac, N, alg)
function qr_null_pullback(ΔN)
ΔA = zero(A)
(m, n) = size(A)
minmn = min(m, n)
ΔQ = zero!(similar(A, (m, m)))
view(ΔQ, 1:m, (minmn + 1):m) .= unthunk.(ΔN)
MatrixAlgebraKit.qr_pullback!(ΔA, A, (Q, R), (ΔQ, ZeroTangent()))
MatrixAlgebraKit.qr_null_pullback!(ΔA, A, N, unthunk(ΔN))
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function qr_null_pullback(::ZeroTangent) # is this extra definition useful?
Expand All @@ -73,18 +73,12 @@ for lq_f in (:lq_compact, :lq_full)
end
end
end
function ChainRulesCore.rrule(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, alg)
function ChainRulesCore.rrule(::typeof(lq_null!), A, Nᴴ, alg)
Ac = copy_input(lq_full, A)
LQ = initialize_output(lq_full!, A, alg)
L, Q = lq_full!(Ac, LQ, alg)
Nᴴ = copy!(Nᴴ, view(Q, (size(A, 1) + 1):size(A, 2), 1:size(A, 2)))
Nᴴ = lq_null!(Ac, Nᴴ, alg)
function lq_null_pullback(ΔNᴴ)
ΔA = zero(A)
(m, n) = size(A)
minmn = min(m, n)
ΔQ = zero!(similar(A, (n, n)))
view(ΔQ, (minmn + 1):n, 1:n) .= unthunk.(ΔNᴴ)
MatrixAlgebraKit.lq_pullback!(ΔA, A, (L, Q), (ZeroTangent(), ΔQ))
MatrixAlgebraKit.lq_null_pullback!(ΔA, A, Nᴴ, unthunk(ΔNᴴ))
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function lq_null_pullback(::ZeroTangent) # is this extra definition useful?
Expand Down
5 changes: 3 additions & 2 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter
)
eval(
Expr(
:public, :qr_pullback!, :lq_pullback!, :svd_pullback!, :svd_trunc_pullback!,
:public, :left_polar_pullback!, :right_polar_pullback!,
:qr_pullback!, :qr_null_pullback!, :lq_pullback!, :lq_null_pullback!,
:eig_pullback!, :eig_trunc_pullback!, :eigh_pullback!, :eigh_trunc_pullback!,
:left_polar_pullback!, :right_polar_pullback!
:svd_pullback!, :svd_trunc_pullback!
)
)
end
Expand Down
1 change: 1 addition & 0 deletions src/common/pullbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ ecosystems
function iszerotangent end

iszerotangent(::Any) = false
iszerotangent(::Nothing) = true
25 changes: 25 additions & 0 deletions src/pullbacks/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,28 @@ function lq_pullback!(
ΔA1 .+= ΔQ̃
return ΔA
end

"""
lq_null_pullback(ΔA, A, Nᴴ, ΔNᴴ)

Adds the pullback from the left nullspace of `A` to `ΔA`, given the nullspace basis
`Nᴴ` and its cotangent `ΔNᴴ` of `lq_null(A)`.

See also [`lq_pullback!`](@ref).
"""
function lq_null_pullback!(
ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ;
tol::Real = default_pullback_gaugetol(A),
gauge_atol::Real = tol
)
if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0
NᴴΔN = Nᴴ * ΔNᴴ'
Δgauge = norm((NᴴΔN .- NᴴΔN') ./ 2)
Δgauge < tol ||
@warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here?
X = ldiv!(LowerTriangular(L)', Q * ΔNᴴ')
ΔA = mul!(ΔA, X, Nᴴ, -1, 1)
end
return ΔA
end
26 changes: 26 additions & 0 deletions src/pullbacks/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,29 @@ function qr_pullback!(
ΔA1 .+= ΔQ̃
return ΔA
end

"""
qr_null_pullback(ΔA, A, N, ΔN)

Adds the pullback from the right nullspace of `A` to `ΔA`, given the nullspace basis
`N` and its cotangent `ΔN` of `qr_null(A)`.

See also [`qr_pullback!`](@ref).
"""
function qr_null_pullback!(
ΔA::AbstractMatrix, A, N, ΔN;
tol::Real = default_pullback_gaugetol(A),
gauge_atol::Real = tol
)
if !iszerotangent(ΔN) && size(N, 2) > 0
NᴴΔN = N' * ΔN
Δgauge = norm((NᴴΔN .- NᴴΔN') ./ 2)
Δgauge < tol ||
@warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"

Q, R = qr_compact(A; positive = true)
X = rdiv!(ΔN' * Q, UpperTriangular(R)')
ΔA = mul!(ΔA, N, X, -1, 1)
end
return ΔA
end