Skip to content

Commit

Permalink
Port Adjoint, Transpose, and triangular from Nabla
Browse files Browse the repository at this point in the history
  • Loading branch information
ararslan committed Jun 11, 2019
1 parent 4d46dc7 commit 51e52e8
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions src/rules/linalg/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,28 @@ rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_ba

_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ)
_symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ

#####
##### `Adjoint`
#####

# TODO: Deal with complex-valued arrays as well
rrule(::Type{<:Adjoint}, A::AbstractVecOrMat{<:Real}) = Adjoint(A), Rule(adjoint)

rrule(::typeof(adjoint), A::AbstractVecOrMat{<:Real}) = adjoint(A), Rule(adjoint)

#####
##### `Transpose`
#####

rrule(::Type{<:Transpose}, A::AbstractVecOrMat) = Transpose(A), Rule(transpose)

rrule(::typeof(transpose), A::AbstractVecOrMat) = transpose(A), Rule(transpose)

#####
##### Triangular matrices
#####

rrule(::Type{<:UpperTriangular}, A::AbstractMatrix) = UpperTriangular(A), Rule(Matrix)

rrule(::Type{<:LowerTriangular}, A::AbstractMatrix) = LowerTriangular(A), Rule(Matrix)

0 comments on commit 51e52e8

Please sign in to comment.