Skip to content

Commit

Permalink
Add SVD factorization rrule (#31)
Browse files Browse the repository at this point in the history
This adds `rrule`s for the SVD factorization as well as an accompanying
`rrule` for `getproperty` on `SVD` objects. The definitions are ported
from Nabla.
  • Loading branch information
ararslan authored May 24, 2019
1 parent 45c79e9 commit c7537c0
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
Cassette = "^0.2"
FDM = "^0.4"
FDM = "^0.5"
julia = "^1.0"

[extras]
Expand Down
1 change: 1 addition & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ include("rules/broadcast.jl")
include("rules/linalg/dense.jl")
include("rules/linalg/diagonal.jl")
include("rules/linalg/symmetric.jl")
include("rules/linalg/factorization.jl")
include("rules/blas.jl")
include("rules/nanmath.jl")
include("rules/specialfunctions.jl")
Expand Down
81 changes: 81 additions & 0 deletions src/rules/linalg/factorization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#####
##### `svd`
#####

# Reverse-mode rule for `svd` of a real matrix.
#
# Returns the factorization `F = svd(X)` together with a `Rule` for `X`. The rule
# takes the cotangent of `F` as a named tuple with fields `U`, `S` and `V` and
# forwards those three fields, along with `F` itself, to `svd_rev`.
function rrule(::typeof(svd), X::AbstractMatrix{<:Real})
    F = svd(X)
    # The `do`-block argument must be named: its fields are read in the body.
    ∂X = Rule() do Ȳ::NamedTuple{(:U,:S,:V)}
        svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V)
    end
    return F, ∂X
end

# Reverse-mode rule for `getproperty` on an `SVD` object.
#
# For `x` in `:U`, `:S`, `:V`, `:Vt` this returns the requested factor and a tuple of
# two rules: one for `F` and a `DNERule()` for the symbol argument. The rule for `F`
# maps the cotangent `Ȳ` of the requested factor to a named tuple `(U=…, S=…, V=…)`
# in which the slots for the factors that were not requested are zero arrays.
#
# NOTE(review): any other property name falls through and yields `nothing`;
# presumably that signals "no rule available" to the caller — confirm.
function rrule(::typeof(getproperty), F::SVD, x::Symbol)
    if x === :U
        return F.U, (Rule(Ȳ->(U=Ȳ, S=zero(F.S), V=zero(F.V))), DNERule())
    elseif x === :S
        return F.S, (Rule(Ȳ->(U=zero(F.U), S=Ȳ, V=zero(F.V))), DNERule())
    elseif x === :V
        return F.V, (Rule(Ȳ->(U=zero(F.U), S=zero(F.S), V=Ȳ)), DNERule())
    elseif x === :Vt
        # The named tuple only has a `V` slot, so the cotangent of `Vt` is stored
        # adjointed there.
        return F.Vt, (Rule(Ȳ->(U=zero(F.U), S=zero(F.S), V=Ȳ')), DNERule())
    end
    return nothing
end

# Combine the cotangents `Ū`, `s̄` and `V̄` of the factors `U`, `S` and `V` of `USV`
# into a single matrix `Ā`, which is returned. `rrule(svd, X)` uses this as the
# pullback with respect to `X`.
function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix)
    # Note: assuming a thin factorization, i.e. svd(A, full=false), which is the default
    U = USV.U
    s = USV.S
    V = USV.V
    Vt = USV.Vt

    k = length(s)
    T = eltype(s)
    # Off the diagonal, F[i,j] = 1 / (s[j]^2 - s[i]^2). The diagonal value is a
    # placeholder: `_mulsubtrans!` overwrites the diagonal of its first argument
    # with zeros and never reads F[i,i].
    F = T[i == j ? 1 : inv(@inbounds s[j]^2 - s[i]^2) for i = 1:k, j = 1:k]

    # We do a lot of matrix operations here, so we'll try to be memory-friendly and do
    # as many of the computations in-place as possible. Benchmarking shows that the in-
    # place functions here are significantly faster than their out-of-place, naively
    # implemented counterparts, and allocate no additional memory.
    Ut = U'
    FUᵀŪ = _mulsubtrans!(Ut*Ū, F)  # F .* (UᵀŪ - ŪᵀU)
    FVᵀV̄ = _mulsubtrans!(Vt*V̄, F)  # F .* (VᵀV̄ - V̄ᵀV)
    ImUUᵀ = _eyesubx!(U*Ut)  # I - UUᵀ
    ImVVᵀ = _eyesubx!(V*Vt)  # I - VVᵀ

    S = Diagonal(s)
    S̄ = Diagonal(s̄)

    # Accumulate the three contributions (from Ū, s̄ and V̄) into a single buffer.
    Ā = _add!(U * FUᵀŪ * S, ImUUᵀ * (Ū / S)) * Vt
    _add!(Ā, U * S̄ * Vt)
    _add!(Ā, U * _add!(S * FVᵀV̄ * Vt, (S \ V̄') * ImVVᵀ))

    return Ā
end

# In-place computation of `F .* (X - X')` on the leading `k × k` part of `X`, where
# `k = size(X, 1)`. The diagonal of `X` is set to zero. Returns `X`.
function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
    n = size(X, 1)
    @inbounds for col in 1:n
        # Walk the strict upper triangle of this column, updating each entry together
        # with its mirror image below the diagonal.
        for row in 1:(col - 1)
            upper = X[row, col]
            lower = X[col, row]
            X[row, col] = F[row, col] * (upper - lower)
            X[col, row] = F[col, row] * (lower - upper)
        end
        X[col, col] = zero(T)
    end
    return X
end

# In-place computation of `I - X`: every entry is negated and 1 is added on the main
# diagonal. Works on the full `size(X)` extent, square or not. Returns `X`.
function _eyesubx!(X::AbstractMatrix{T}) where T<:Real
    nrows, ncols = size(X)
    @inbounds for col in 1:ncols
        for row in 1:nrows
            # `row == col` is a `Bool`, contributing 1 on the diagonal and 0 elsewhere.
            X[row, col] = (row == col) - X[row, col]
        end
    end
    return X
end

# In-place elementwise sum: adds each entry of `Y` to the matching entry of `X` and
# returns `X`. The indices come from `eachindex(X, Y)`.
function _add!(X::AbstractMatrix{T}, Y::AbstractMatrix{T}) where T<:Real
    @inbounds for idx in eachindex(X, Y)
        X[idx] = X[idx] + Y[idx]
    end
    return X
end
24 changes: 24 additions & 0 deletions test/rules/linalg/factorization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
@testset "Factorizations" begin
    @testset "svd" begin
        rng = MersenneTwister(2)
        # Compare the reverse-mode rules against finite differences for every
        # property of the factorization, on tall, wide and square inputs.
        for n in [4, 6, 10], m in [3, 5, 10]
            X = randn(rng, n, m)
            F, dX = rrule(svd, X)
            for p in [:U, :S, :V, :Vt]
                Y, (dF, dp) = rrule(getproperty, F, p)
                @test dp isa ChainRules.DNERule
                # Random cotangent with the same shape as the requested factor.
                Ȳ = randn(rng, size(Y)...)
                X̄_ad = dX(dF(Ȳ))
                X̄_fd = j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X)
                @test X̄_ad ≈ X̄_fd rtol=1e-6 atol=1e-6
            end
        end
        @testset "Helper functions" begin
            X = randn(rng, 10, 10)
            Y = randn(rng, 10, 10)
            @test ChainRules._mulsubtrans!(copy(X), Y) ≈ Y .* (X - X')
            @test ChainRules._eyesubx!(copy(X)) ≈ I - X
            @test ChainRules._add!(copy(X), Y) ≈ X + Y
        end
    end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ include("test_util.jl")
include(joinpath("rules", "linalg", "dense.jl"))
include(joinpath("rules", "linalg", "diagonal.jl"))
include(joinpath("rules", "linalg", "symmetric.jl"))
include(joinpath("rules", "linalg", "factorization.jl"))
end
include(joinpath("rules", "broadcast.jl"))
include(joinpath("rules", "blas.jl"))
Expand Down

0 comments on commit c7537c0

Please sign in to comment.