diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index da153d14e..245335774 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -96,7 +96,7 @@ end function _diagm_back(p, ȳ) k, v = p - d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix + d = diag(unthunk(ȳ), k)[eachindex(v)] # handle if diagonal was smaller than matrix return Tangent{typeof(p)}(second = d) end diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 8ee8f8cd0..06de7a135 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -137,3 +137,26 @@ function rrule(::typeof(det), x::SparseMatrixCSC) end return Ω, det_pullback end + + +function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...) + + function spdiagm_pullback(ȳ) + return (NoTangent(), NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ȳ))...) + end + return spdiagm(m, n, kv...), spdiagm_pullback +end + +function rrule(::typeof(spdiagm), kv::Pair{<:Integer,<:AbstractVector}...) + function spdiagm_pullback(ȳ) + return (NoTangent(), _diagm_back.(kv, Ref(ȳ))...) + end + return spdiagm(kv...), spdiagm_pullback +end + +function rrule(::typeof(spdiagm), v::AbstractVector) + function spdiagm_pullback(ȳ) + return (NoTangent(), diag(unthunk(ȳ))) + end + return spdiagm(v), spdiagm_pullback +end diff --git a/test/rulesets/SparseArrays/sparsematrix.jl b/test/rulesets/SparseArrays/sparsematrix.jl index 03f1052c2..283452a8a 100644 --- a/test/rulesets/SparseArrays/sparsematrix.jl +++ b/test/rulesets/SparseArrays/sparsematrix.jl @@ -18,6 +18,51 @@ end test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-4) end +# copied over from test/rulesets/LinearAlgebra/structured +@testset "spdiagm" begin + @testset "without size" begin + M, N = 7, 9 + s = (8, 8) + a = randn(M) + b = randn(M) + c = randn(M - 1) + ȳ = randn(s) + ps = (0 => a, 1 => b, 0 => c) + y, back = rrule(spdiagm, ps...) + @test y == spdiagm(ps...) + ∂self, ∂pa, ∂pb, ∂pc = back(ȳ) + @test ∂self === NoTangent() + ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(0 => a, 1 => b, 0 => c), ȳ, a, b, c) + for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd)) + ∂px = unthunk(∂px) + @test ∂px isa Tangent{typeof(p)} + @test ∂px.first isa AbstractZero + @test ∂px.second ≈ ∂x_fd + end + end + @testset "with size" begin + M, N = 7, 9 + a = randn(M) + b = randn(M) + c = randn(M - 1) + ȳ = randn(M, N) + ps = (0 => a, 1 => b, 0 => c) + y, back = rrule(spdiagm, M, N, ps...) + @test y == spdiagm(M, N, ps...) + ∂self, ∂M, ∂N, ∂pa, ∂pb, ∂pc = back(ȳ) + @test ∂self === NoTangent() + @test ∂M === NoTangent() + @test ∂N === NoTangent() + ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c) + for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd)) + ∂px = unthunk(∂px) + @test ∂px isa Tangent{typeof(p)} + @test ∂px.first isa AbstractZero + @test ∂px.second ≈ ∂x_fd + end + end +end + @testset "findnz" begin A = sprand(5, 5, 0.5) dA = similar(A) @@ -42,4 +87,4 @@ end test_rrule(logabsdet, A) test_rrule(logdet, A) test_rrule(det, A) -end \ No newline at end of file +end