From 4d46dc7f0cb0de22e85b15a1a2e08ddd52bb668d Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Mon, 10 Jun 2019 15:12:59 -0700 Subject: [PATCH] Move some code around (NFC) * Move the BLAS definitions into the `linalg` directory. * Move some general optimizations into a specific utilities file for linear algebra definitions. * Consolidate structured matrix operations, e.g. diagonal and symmetric, into one file. --- src/ChainRules.jl | 6 +++--- src/rules/{ => linalg}/blas.jl | 0 src/rules/linalg/diagonal.jl | 2 -- src/rules/linalg/factorization.jl | 27 -------------------------- src/rules/linalg/structured.jl | 18 +++++++++++++++++ src/rules/linalg/symmetric.jl | 4 ---- src/rules/linalg/utils.jl | 32 +++++++++++++++++++++++++++++++ 7 files changed, 53 insertions(+), 36 deletions(-) rename src/rules/{ => linalg}/blas.jl (100%) delete mode 100644 src/rules/linalg/diagonal.jl create mode 100644 src/rules/linalg/structured.jl delete mode 100644 src/rules/linalg/symmetric.jl create mode 100644 src/rules/linalg/utils.jl diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 0b0e775d3..cf812063b 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -13,11 +13,11 @@ include("rules.jl") include("rules/base.jl") include("rules/array.jl") include("rules/broadcast.jl") +include("rules/linalg/utils.jl") +include("rules/linalg/blas.jl") include("rules/linalg/dense.jl") -include("rules/linalg/diagonal.jl") -include("rules/linalg/symmetric.jl") +include("rules/linalg/structured.jl") include("rules/linalg/factorization.jl") -include("rules/blas.jl") include("rules/nanmath.jl") include("rules/specialfunctions.jl") diff --git a/src/rules/blas.jl b/src/rules/linalg/blas.jl similarity index 100% rename from src/rules/blas.jl rename to src/rules/linalg/blas.jl diff --git a/src/rules/linalg/diagonal.jl b/src/rules/linalg/diagonal.jl deleted file mode 100644 index cd9a40959..000000000 --- a/src/rules/linalg/diagonal.jl +++ /dev/null @@ -1,2 +0,0 @@ -rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), Rule(diag) -rrule(::typeof(diag), A::AbstractMatrix) = diag(A), Rule(Diagonal) diff --git a/src/rules/linalg/factorization.jl b/src/rules/linalg/factorization.jl index 0c04618bc..72527fcc6 100644 --- a/src/rules/linalg/factorization.jl +++ b/src/rules/linalg/factorization.jl @@ -59,33 +59,6 @@ function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::Abstra return Ā end -function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real - k = size(X, 1) - @inbounds for j = 1:k, i = 1:j # Iterate the upper triangle - if i == j - X[i,i] = zero(T) - else - X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j]) - end - end - X -end - -function _eyesubx!(X::AbstractMatrix{T}) where T<:Real - n, m = size(X) - @inbounds for j = 1:m, i = 1:n - X[i,j] = (i == j) - X[i,j] - end - X -end - -function _add!(X::AbstractMatrix{T}, Y::AbstractMatrix{T}) where T<:Real - @inbounds for i = eachindex(X, Y) - X[i] += Y[i] - end - X -end - ##### ##### `cholesky` ##### diff --git a/src/rules/linalg/structured.jl b/src/rules/linalg/structured.jl new file mode 100644 index 000000000..34b19339c --- /dev/null +++ b/src/rules/linalg/structured.jl @@ -0,0 +1,18 @@ +# Structured matrices + +##### +##### `Diagonal` +##### + +rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), Rule(diag) + +rrule(::typeof(diag), A::AbstractMatrix) = diag(A), Rule(Diagonal) + +##### +##### `Symmetric` +##### + +rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back) + +_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ) +_symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ diff --git a/src/rules/linalg/symmetric.jl b/src/rules/linalg/symmetric.jl deleted file mode 100644 index 4b1f861d6..000000000 --- a/src/rules/linalg/symmetric.jl +++ /dev/null @@ -1,4 +0,0 @@ -rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back) - -_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ) -_symmetric_back(ΔΩ::Union{Diagonal, UpperTriangular}) = ΔΩ diff --git a/src/rules/linalg/utils.jl b/src/rules/linalg/utils.jl new file mode 100644 index 000000000..ed9a9cb10 --- /dev/null +++ b/src/rules/linalg/utils.jl @@ -0,0 +1,32 @@ +# Some utility functions for optimizing linear algebra operations that aren't specific +# to any particular rule definition + +# F .* (X - X'), overwrites X +function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real + k = size(X, 1) + @inbounds for j = 1:k, i = 1:j # Iterate the upper triangle + if i == j + X[i,i] = zero(T) + else + X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j]) + end + end + X +end + +# I - X, overwrites X +function _eyesubx!(X::AbstractMatrix) + n, m = size(X) + @inbounds for j = 1:m, i = 1:n + X[i,j] = (i == j) - X[i,j] + end + X +end + +# X + Y, overwrites X +function _add!(X::AbstractVecOrMat{T}, Y::AbstractVecOrMat{T}) where T<:Real + @inbounds for i = eachindex(X, Y) + X[i] += Y[i] + end + X +end