Skip to content

Commit

Permalink
Merge pull request #42 from TuringLang/mt/fix_linalg
Browse files Browse the repository at this point in the history
Some linear algebra fixes
  • Loading branch information
mohamed82008 authored Mar 14, 2020
2 parents f32c534 + 7dbba19 commit c01566e
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 45 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ SpecialFunctions = "0.8, 0.9, 0.10"
StatsBase = "0.32"
StatsFuns = "0.8, 0.9"
Tracker = "0.2.5"
Zygote = "0.4.7"
Zygote = "0.4.10"
ZygoteRules = "0.2"
julia = "1"

Expand Down
2 changes: 1 addition & 1 deletion src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
TrackedVecOrMat, track, @grad, data
using SpecialFunctions: logabsgamma, digamma
using ZygoteRules: ZygoteRules, @adjoint, pullback
using LinearAlgebra: copytri!
using LinearAlgebra: copytri!, AbstractTriangular
using Distributions: AbstractMvLogNormal,
ContinuousMultivariateDistribution
using DiffRules, SpecialFunctions, FillArrays
Expand Down
73 changes: 38 additions & 35 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,42 @@ end

## Linear algebra ##

LinearAlgebra.UpperTriangular(A::TrackedMatrix) = track(UpperTriangular, A)
@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix)
return UpperTriangular(data(A)), Δ->(UpperTriangular(Δ),)
# Work around https://github.com/FluxML/Tracker.jl/pull/9#issuecomment-480051767

# Wrap a plain matrix in the matching triangular type from LinearAlgebra.
upper(A::AbstractMatrix) = UpperTriangular(A)
lower(A::AbstractMatrix) = LowerTriangular(A)

# Upper triangular factor of a `Cholesky` object. When the factorisation is
# flagged 'U' the stored factors are wrapped directly; otherwise the adjoint of
# the lower-wrapped factors is materialised with `copy`.
function upper(C::Cholesky)
    return C.uplo == 'U' ? upper(C.factors) : copy(adjoint(lower(C.factors)))
end

# Lower triangular factor of a `Cholesky` object; mirror image of `upper`:
# a 'U'-flagged factorisation yields a copied adjoint of the upper-wrapped
# factors, anything else wraps the stored factors directly.
function lower(C::Cholesky)
    return C.uplo == 'U' ? copy(adjoint(upper(C.factors))) : lower(C.factors)
end

# Constructing a `LowerTriangular` from a tracked matrix is routed through
# `lower`, whose tracked method records the call with `track`.
LinearAlgebra.LowerTriangular(A::TrackedMatrix) = lower(A)
lower(A::TrackedMatrix) = track(lower, A)
# Gradient rule for `lower`: the primal is `lower` applied to the untracked
# data, and the pullback applies `lower` to the incoming sensitivity `∇`.
@grad lower(A) = lower(Tracker.data(A)), ∇ -> (lower(∇),)

# Constructing an `UpperTriangular` from a tracked matrix is routed through
# `upper`, whose tracked method records the call with `track`.
LinearAlgebra.UpperTriangular(A::TrackedMatrix) = upper(A)
upper(A::TrackedMatrix) = track(upper, A)
# Gradient rule for `upper`: the primal is `upper` applied to the untracked
# data, and the pullback applies `upper` to the incoming sensitivity `∇`.
@grad upper(A) = upper(Tracker.data(A)), ∇ -> (upper(∇),)

# `copy` of a tracked array whose underlying data is the adjoint of a triangular
# matrix: the call is recorded with `track` rather than executed directly.
Base.copy(
    A::TrackedArray{T, 2, <:Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}}},
) where {T <: Real} = track(copy, A)
# Gradient rule for the tracked `copy` of an adjoint-wrapped triangular matrix:
# the primal copies the untracked data, and the pullback copies the incoming
# sensitivity `∇`.
@grad function Base.copy(
A::TrackedArray{T, 2, <:Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}}},
) where {T <: Real}
return copy(data(A)), ∇ -> (copy(∇),)
end

function LinearAlgebra.cholesky(A::TrackedMatrix; check=true)
Expand All @@ -57,40 +90,10 @@ function turing_chol(A::AbstractMatrix, check)
end
# Tracked entry point: record the `turing_chol` call (with its `check` flag)
# via `track` so the `@grad` rule for `turing_chol` supplies the pullback.
turing_chol(A::TrackedMatrix, check) = track(turing_chol, A, check)
@grad function turing_chol(A::AbstractMatrix, check)
C, back = pullback(unsafe_cholesky, data(A), data(check))
C, back = pullback(_turing_chol, data(A), data(check))
return (C.factors, C.info), Δ->back((factors=data(Δ[1]),))
end

unsafe_cholesky(x, check) = cholesky(x, check=check)
@adjoint function unsafe_cholesky(Σ::Real, check)
C = cholesky(Σ; check=check)
return C, function(Δ::NamedTuple)
issuccess(C) || return (zero(Σ), nothing)
(Δ.factors[1, 1] / (2 * C.U[1, 1]), nothing)
end
end
@adjoint function unsafe_cholesky(Σ::Diagonal, check)
C = cholesky(Σ; check=check)
return C, function(Δ::NamedTuple)
issuccess(C) || (Diagonal(zero(diag(Δ.factors))), nothing)
(Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)), nothing)
end
end
@adjoint function unsafe_cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
C = cholesky(Σ; check=check)
return C, function(Δ::NamedTuple)
issuccess(C) || return (zero(Δ.factors), nothing)
U, Ū = C.U, Δ.factors
Σ̄ = Ū * U'
Σ̄ = copytri!(Σ̄, 'U')
Σ̄ = ldiv!(U, Σ̄)
BLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄)
@inbounds for n in diagind(Σ̄)
Σ̄[n] /= 2
end
return (UpperTriangular(Σ̄), nothing)
end
end
# Plain Cholesky factorisation that takes `check` positionally and forwards it
# to `cholesky` as the `check` keyword.
function _turing_chol(x, check)
    return cholesky(x; check=check)
end

# Specialised logdet for a Cholesky factorisation that works on the triangular
# factor directly: twice the sum of the logs of the diagonal entries of `U`.
function logdet_chol_tri(U::AbstractMatrix)
    diagonal = U[diagind(U)]
    return 2 * sum(log, diagonal)
end
Expand Down
7 changes: 0 additions & 7 deletions test/others.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
using StatsBase: entropy

if get_stage() in ("Others", "all")
@testset "unsafe_cholesky" begin
A = rand(3, 3); A = A + A' + 3I
@test Matrix(DistributionsAD.unsafe_cholesky(A, true)) == Matrix(cholesky(A))
@test !issuccess(DistributionsAD.unsafe_cholesky(rand(3,3), false))
@test_throws PosDefException DistributionsAD.unsafe_cholesky(rand(3,3), true)
end

@testset "TuringWishart" begin
dim = 3
A = Matrix{Float64}(I, dim, dim)
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using DistributionsAD, Test, LinearAlgebra, Combinatorics
using ForwardDiff: Dual
using StatsFuns: binomlogpdf, logsumexp
const FDM = FiniteDifferences
using DistributionsAD: TuringMvNormal, TuringMvLogNormal, TuringUniform, unsafe_cholesky
using DistributionsAD: TuringMvNormal, TuringMvLogNormal, TuringUniform
using Distributions: meanlogdet

include("test_utils.jl")
Expand Down

0 comments on commit c01566e

Please sign in to comment.