Skip to content

Commit

Permalink
Merge pull request #740 from ElOceanografo/spdiag
Browse files Browse the repository at this point in the history
Add rrule for spdiagm
  • Loading branch information
oxinabox authored Oct 24, 2023
2 parents e3b8bf5 + 7a6d648 commit aa5abed
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ end

function _diagm_back(p, ȳ)
k, v = p
d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix
d = diag(unthunk(ȳ), k)[eachindex(v)] # handle if diagonal was smaller than matrix
return Tangent{typeof(p)}(second = d)
end

Expand Down
23 changes: 23 additions & 0 deletions src/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,26 @@ function rrule(::typeof(det), x::SparseMatrixCSC)
end
return Ω, det_pullback
end


function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...)

function spdiagm_pullback(ȳ)
return (NoTangent(), NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ȳ))...)
end
return spdiagm(m, n, kv...), spdiagm_pullback
end

function rrule(::typeof(spdiagm), kv::Pair{<:Integer,<:AbstractVector}...)
function spdiagm_pullback(ȳ)
return (NoTangent(), _diagm_back.(kv, Ref(ȳ))...)
end
return spdiagm(kv...), spdiagm_pullback
end

function rrule(::typeof(spdiagm), v::AbstractVector)
function spdiagm_pullback(ȳ)
return (NoTangent(), diag(unthunk(ȳ)))
end
return spdiagm(v), spdiagm_pullback
end
47 changes: 46 additions & 1 deletion test/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,51 @@ end
test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-4)
end

# copied over from test/rulesets/LinearAlgebra/structured
@testset "spdiagm" begin
@testset "without size" begin
M, N = 7, 9
s = (8, 8)
a = randn(M)
b = randn(M)
c = randn(M - 1)
= randn(s)
ps = (0 => a, 1 => b, 0 => c)
y, back = rrule(spdiagm, ps...)
@test y == spdiagm(ps...)
∂self, ∂pa, ∂pb, ∂pc = back(ȳ)
@test ∂self === NoTangent()
∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(0 => a, 1 => b, 0 => c), ȳ, a, b, c)
for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd))
∂px = unthunk(∂px)
@test ∂px isa Tangent{typeof(p)}
@test ∂px.first isa AbstractZero
@test ∂px.second ∂x_fd
end
end
@testset "with size" begin
M, N = 7, 9
a = randn(M)
b = randn(M)
c = randn(M - 1)
= randn(M, N)
ps = (0 => a, 1 => b, 0 => c)
y, back = rrule(spdiagm, M, N, ps...)
@test y == spdiagm(M, N, ps...)
∂self, ∂M, ∂N, ∂pa, ∂pb, ∂pc = back(ȳ)
@test ∂self === NoTangent()
@test ∂M === NoTangent()
@test ∂N === NoTangent()
∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c)
for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd))
∂px = unthunk(∂px)
@test ∂px isa Tangent{typeof(p)}
@test ∂px.first isa AbstractZero
@test ∂px.second ∂x_fd
end
end
end

@testset "findnz" begin
A = sprand(5, 5, 0.5)
dA = similar(A)
Expand All @@ -42,4 +87,4 @@ end
test_rrule(logabsdet, A)
test_rrule(logdet, A)
test_rrule(det, A)
end
end

0 comments on commit aa5abed

Please sign in to comment.