Skip to content

Commit

Permalink
Merge pull request #51 from JuliaDiff/aa/restructure
Browse files Browse the repository at this point in the history
Minor code movement (NFC) and port some matrix types from Nabla
  • Loading branch information
ararslan authored Jun 13, 2019
2 parents 0e551fb + 76ea4b8 commit dc4adb0
Show file tree
Hide file tree
Showing 11 changed files with 102 additions and 46 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ julia = "^1.0"

[extras]
FDM = "e25cca7e-83ef-51fa-be6c-dfe2a3123128"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "FDM"]
test = ["FDM", "Random", "Test"]
6 changes: 3 additions & 3 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ include("rules.jl")
include("rules/base.jl")
include("rules/array.jl")
include("rules/broadcast.jl")
include("rules/linalg/utils.jl")
include("rules/linalg/blas.jl")
include("rules/linalg/dense.jl")
include("rules/linalg/diagonal.jl")
include("rules/linalg/symmetric.jl")
include("rules/linalg/structured.jl")
include("rules/linalg/factorization.jl")
include("rules/blas.jl")
include("rules/nanmath.jl")
include("rules/specialfunctions.jl")

Expand Down
File renamed without changes.
2 changes: 0 additions & 2 deletions src/rules/linalg/diagonal.jl

This file was deleted.

27 changes: 0 additions & 27 deletions src/rules/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,33 +59,6 @@ function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::Abstra
return
end

function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
k = size(X, 1)
@inbounds for j = 1:k, i = 1:j # Iterate the upper triangle
if i == j
X[i,i] = zero(T)
else
X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j])
end
end
X
end

function _eyesubx!(X::AbstractMatrix{T}) where T<:Real
n, m = size(X)
@inbounds for j = 1:m, i = 1:n
X[i,j] = (i == j) - X[i,j]
end
X
end

function _add!(X::AbstractMatrix{T}, Y::AbstractMatrix{T}) where T<:Real
@inbounds for i = eachindex(X, Y)
X[i] += Y[i]
end
X
end

#####
##### `cholesky`
#####
Expand Down
47 changes: 47 additions & 0 deletions src/rules/linalg/structured.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Structured matrices

#####
##### `Diagonal`
#####

rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), Rule(diag)

rrule(::typeof(diag), A::AbstractMatrix) = diag(A), Rule(Diagonal)

#####
##### `Symmetric`
#####

rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back)

_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ)
_symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ

#####
##### `Adjoint`
#####

# TODO: Deal with complex-valued arrays as well
rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) = Adjoint(A), Rule(adjoint)
rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) = Adjoint(A), Rule(vecadjoint)

rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) = adjoint(A), Rule(adjoint)
rrule(::typeof(adjoint), A::AbstractVector{<:Real}) = adjoint(A), Rule(vecadjoint)

#####
##### `Transpose`
#####

rrule(::Type{<:Transpose}, A::AbstractMatrix) = Transpose(A), Rule(transpose)
rrule(::Type{<:Transpose}, A::AbstractVector) = Transpose(A), Rule(vectranspose)

rrule(::typeof(transpose), A::AbstractMatrix) = transpose(A), Rule(transpose)
rrule(::typeof(transpose), A::AbstractVector) = transpose(A), Rule(vectranspose)

#####
##### Triangular matrices
#####

rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) = UpperTriangular(A), Rule(Matrix)

rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) = LowerTriangular(A), Rule(Matrix)
4 changes: 0 additions & 4 deletions src/rules/linalg/symmetric.jl

This file was deleted.

32 changes: 32 additions & 0 deletions src/rules/linalg/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Some utility functions for optimizing linear algebra operations that aren't specific
# to any particular rule definition

# F .* (X - X'), overwrites X
function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
k = size(X, 1)
@inbounds for j = 1:k, i = 1:j # Iterate the upper triangle
if i == j
X[i,i] = zero(T)
else
X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j])
end
end
X
end

# I - X, overwrites X
function _eyesubx!(X::AbstractMatrix)
n, m = size(X)
@inbounds for j = 1:m, i = 1:n
X[i,j] = (i == j) - X[i,j]
end
X
end

# X + Y, overwrites X
function _add!(X::AbstractVecOrMat{T}, Y::AbstractVecOrMat{T}) where T<:Real
@inbounds for i = eachindex(X, Y)
X[i] += Y[i]
end
X
end
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testset "diagonal" begin
@testset "Structured Matrices" begin
@testset "Diagonal" begin
rng, N = MersenneTwister(123456), 3
rrule_test(Diagonal, randn(rng, N, N), (randn(rng, N), randn(rng, N)))
Expand All @@ -14,4 +14,20 @@
rrule_test(diag, randn(rng, N), (randn(rng, N, N), Diagonal(randn(rng, N))))
rrule_test(diag, randn(rng, N), (Diagonal(randn(rng, N)), Diagonal(randn(rng, N))))
end
@testset "Symmetric" begin
rng, N = MersenneTwister(123456), 3
rrule_test(Symmetric, randn(rng, N, N), (randn(rng, N, N), randn(rng, N, N)))
end
@testset "$f" for f in (Adjoint, adjoint, Transpose, transpose)
rng = MersenneTwister(32)
n = 5
m = 3
rrule_test(f, randn(rng, m, n), (randn(rng, n, m), randn(rng, n, m)))
rrule_test(f, randn(rng, 1, n), (randn(rng, n), randn(rng, n)))
end
@testset "$T" for T in (UpperTriangular, LowerTriangular)
rng = MersenneTwister(33)
n = 5
rrule_test(T, T(randn(rng, n, n)), (randn(rng, n, n), randn(rng, n, n)))
end
end
6 changes: 0 additions & 6 deletions test/rules/linalg/symmetric.jl

This file was deleted.

3 changes: 1 addition & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ include("test_util.jl")
include(joinpath("rules", "array.jl"))
@testset "linalg" begin
include(joinpath("rules", "linalg", "dense.jl"))
include(joinpath("rules", "linalg", "diagonal.jl"))
include(joinpath("rules", "linalg", "symmetric.jl"))
include(joinpath("rules", "linalg", "structured.jl"))
include(joinpath("rules", "linalg", "factorization.jl"))
end
include(joinpath("rules", "broadcast.jl"))
Expand Down

0 comments on commit dc4adb0

Please sign in to comment.