Skip to content

Commit

Permalink
Merge pull request #207 from JuliaDiff/ox/anycompositewilldo
Browse files Browse the repository at this point in the history
Remove constraints on primal from Composites for SVD and Cholesky
  • Loading branch information
oxinabox authored Jun 11, 2020
2 parents 3a969a2 + 141bc66 commit 0e6e00c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 27 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.6.3"
version = "0.6.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
52 changes: 26 additions & 26 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,22 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger!

function rrule(::typeof(svd), X::AbstractMatrix{<:Real})
F = svd(X)
function svd_pullback(::Composite{<:SVD})
∂X = @thunk(svd_rev(F, .U, .S, .V))
function svd_pullback(Ȳ::Composite)
∂X = @thunk(svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V))
return (NO_FIELDS, ∂X)
end
return F, svd_pullback
end

function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD
function getproperty_svd_pullback()
function getproperty_svd_pullback(Ȳ)
C = Composite{T}
∂F = if x === :U
C(U=,)
C(U=Ȳ,)
elseif x === :S
C(S=,)
C(S=Ȳ,)
elseif x === :V
C(V=,)
C(V=Ȳ,)
elseif x === :Vt
# TODO: https://github.com/JuliaDiff/ChainRules.jl/issues/106
throw(ArgumentError("Vt is unsupported; use V and transpose the result"))
Expand All @@ -32,8 +32,8 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD
return getproperty(F, x), getproperty_svd_pullback
end

# When not `Zero`s expect `::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix`
function svd_rev(USV::SVD, , s̄, V̄)
# When not `Zero`s expect `Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix`
function svd_rev(USV::SVD, Ū, s̄, V̄)
# Note: assuming a thin factorization, i.e. svd(A, full=false), which is the default
U = USV.U
s = USV.S
Expand All @@ -49,7 +49,7 @@ function svd_rev(USV::SVD, Ū, s̄, V̄)
# place functions here are significantly faster than their out-of-place, naively
# implemented counterparts, and allocate no additional memory.
Ut = U'
FUᵀŪ = _mulsubtrans!(Ut*, F) # F .* (UᵀŪ - ŪᵀU)
FUᵀŪ = _mulsubtrans!(Ut*Ū, F) # F .* (UᵀŪ - ŪᵀU)
FVᵀV̄ = _mulsubtrans!(Vt*V̄, F) # F .* (VᵀV̄ - V̄ᵀV)
ImUUᵀ = _eyesubx!(U*Ut) # I - UUᵀ
ImVVᵀ = _eyesubx!(V*Vt) # I - VVᵀ
Expand All @@ -58,11 +58,11 @@ function svd_rev(USV::SVD, Ū, s̄, V̄)
=isa AbstractZero ?: Diagonal(s̄)

# TODO: consider using MuladdMacro here
= _add!(U * FUᵀŪ * S, ImUUᵀ * ( / S)) * Vt
= _add!(, U ** Vt)
= _add!(, U * _add!(S * FVᵀV̄ * Vt, (S \') * ImVVᵀ))
Ā = _add!(U * FUᵀŪ * S, ImUUᵀ * (Ū / S)) * Vt
Ā = _add!(Ā, U ** Vt)
Ā = _add!(Ā, U * _add!(S * FVᵀV̄ * Vt, (S \') * ImVVᵀ))

return
return Ā
end

#####
Expand All @@ -71,31 +71,31 @@ end

function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real})
F = cholesky(X)
function cholesky_pullback(::Composite{<:Cholesky})
function cholesky_pullback(Ȳ::Composite)
∂X = if F.uplo === 'U'
@thunk(chol_blocked_rev(.U, F.U, 25, true))
@thunk(chol_blocked_rev(Ȳ.U, F.U, 25, true))
else
@thunk(chol_blocked_rev(.L, F.L, 25, false))
@thunk(chol_blocked_rev(Ȳ.L, F.L, 25, false))
end
return (NO_FIELDS, ∂X)
end
return F, cholesky_pullback
end

function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky
function getproperty_cholesky_pullback()
function getproperty_cholesky_pullback(Ȳ)
C = Composite{T}
∂F = @thunk if x === :U
if F.uplo === 'U'
C(U=UpperTriangular(),)
C(U=UpperTriangular(Ȳ),)
else
C(L=LowerTriangular('),)
C(L=LowerTriangular(Ȳ'),)
end
elseif x === :L
if F.uplo === 'L'
C(L=LowerTriangular(),)
C(L=LowerTriangular(Ȳ),)
else
C(U=UpperTriangular('),)
C(U=UpperTriangular(Ȳ'),)
end
end
return NO_FIELDS, ∂F, DoesNotExist()
Expand Down Expand Up @@ -159,14 +159,14 @@ function level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool)
end

"""
chol_unblocked_rev!(::AbstractMatrix, L::AbstractMatrix, upper::Bool)
chol_unblocked_rev!(Ā::AbstractMatrix, L::AbstractMatrix, upper::Bool)
Compute the reverse-mode sensitivities of the Cholesky factorization in an unblocked manner.
If `upper` is `false`, then the sensitivites are computed from and stored in the lower triangle
of `` and `L` respectively. If `upper` is `true` then they are computed and stored in the
upper triangles. If at input `upper` is `false` and `tril() = L̄`, at output
`tril() = tril(Σ̄)`, where `Σ = LLᵀ`. Analogously, if at input `upper` is `true` and
`triu() = triu()`, at output `triu() = triu(Σ̄)` where `Σ = UᵀU`.
of `Ā` and `L` respectively. If `upper` is `true` then they are computed and stored in the
upper triangles. If at input `upper` is `false` and `tril(Ā) = L̄`, at output
`tril(Ā) = tril(Σ̄)`, where `Σ = LLᵀ`. Analogously, if at input `upper` is `true` and
`triu(Ā) = triu(Ū)`, at output `triu(Ā) = triu(Σ̄)` where `Σ = UᵀU`.
"""
function chol_unblocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, upper::Bool) where T<:Real
n = checksquare(Σ̄)
Expand Down

2 comments on commit 0e6e00c

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/16201

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.4 -m "<description of version>" 0e6e00c2d254e3188aa130eaa71532be56e32957
git push origin v0.6.4

Please sign in to comment.