Skip to content

Commit

Permalink
Move some code around (NFC)
Browse files Browse the repository at this point in the history
* Move the BLAS definitions into the `linalg` directory.
* Move some general optimizations into a specific utilities file for
  linear algebra definitions.
* Consolidate structured matrix operations, e.g. diagonal and symmetric,
  into one file.
  • Loading branch information
ararslan committed Jun 11, 2019
1 parent 0e551fb commit 4d46dc7
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 36 deletions.
6 changes: 3 additions & 3 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ include("rules.jl")
include("rules/base.jl")
include("rules/array.jl")
include("rules/broadcast.jl")
include("rules/linalg/utils.jl")
include("rules/linalg/blas.jl")
include("rules/linalg/dense.jl")
include("rules/linalg/diagonal.jl")
include("rules/linalg/symmetric.jl")
include("rules/linalg/structured.jl")
include("rules/linalg/factorization.jl")
include("rules/blas.jl")
include("rules/nanmath.jl")
include("rules/specialfunctions.jl")

Expand Down
File renamed without changes.
2 changes: 0 additions & 2 deletions src/rules/linalg/diagonal.jl

This file was deleted.

27 changes: 0 additions & 27 deletions src/rules/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,33 +59,6 @@ function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::Abstra
return
end

function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
k = size(X, 1)
@inbounds for j = 1:k, i = 1:j # Iterate the upper triangle
if i == j
X[i,i] = zero(T)
else
X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j])
end
end
X
end

function _eyesubx!(X::AbstractMatrix{T}) where T<:Real
n, m = size(X)
@inbounds for j = 1:m, i = 1:n
X[i,j] = (i == j) - X[i,j]
end
X
end

function _add!(X::AbstractMatrix{T}, Y::AbstractMatrix{T}) where T<:Real
@inbounds for i = eachindex(X, Y)
X[i] += Y[i]
end
X
end

#####
##### `cholesky`
#####
Expand Down
18 changes: 18 additions & 0 deletions src/rules/linalg/structured.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Structured matrices

#####
##### `Diagonal`
#####

rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), Rule(diag)

rrule(::typeof(diag), A::AbstractMatrix) = diag(A), Rule(Diagonal)

#####
##### `Symmetric`
#####

rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back)

_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ)
_symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ
4 changes: 0 additions & 4 deletions src/rules/linalg/symmetric.jl

This file was deleted.

32 changes: 32 additions & 0 deletions src/rules/linalg/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Some utility functions for optimizing linear algebra operations that aren't specific
# to any particular rule definition

# F .* (X - X'), overwrites X
function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
k = size(X, 1)
@inbounds for j = 1:k, i = 1:j # Iterate the upper triangle
if i == j
X[i,i] = zero(T)
else
X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j])
end
end
X
end

# I - X, overwrites X
function _eyesubx!(X::AbstractMatrix)
n, m = size(X)
@inbounds for j = 1:m, i = 1:n
X[i,j] = (i == j) - X[i,j]
end
X
end

# X + Y, overwrites X
function _add!(X::AbstractVecOrMat{T}, Y::AbstractVecOrMat{T}) where T<:Real
@inbounds for i = eachindex(X, Y)
X[i] += Y[i]
end
X
end

0 comments on commit 4d46dc7

Please sign in to comment.