Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rrule for spdiagm #740

Merged
merged 10 commits into from
Oct 24, 2023
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "1.55.0"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Expand Down
2 changes: 1 addition & 1 deletion src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ end

# Cotangent for a single `k => v` diagonal pair `p`, given the cotangent `ȳ` of the
# matrix that was built from it.
#
# Returns a `Tangent` for the pair type whose `second` field holds the entries of the
# `k`-th diagonal of `ȳ` that line up with `v`; no `first` field is supplied for the
# offset `k`.
function _diagm_back(p, ȳ)
    k, v = p
    # The k-th diagonal of `ȳ` can be longer than `v` (the diagonal was smaller than the
    # matrix), so keep only the entries at `v`'s own indices.
    d = diag(unthunk(ȳ), k)[eachindex(v)]
    return Tangent{typeof(p)}(second = d)
end

Expand Down
23 changes: 23 additions & 0 deletions src/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,26 @@
end
return Ω, det_pullback
end


Check warning on line 141 in src/rulesets/SparseArrays/sparsematrix.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/rulesets/SparseArrays/sparsematrix.jl:141:- src/rulesets/SparseArrays/sparsematrix.jl:142:-function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...) src/rulesets/SparseArrays/sparsematrix.jl:143:- src/rulesets/SparseArrays/sparsematrix.jl:152:+function rrule( src/rulesets/SparseArrays/sparsematrix.jl:153:+ ::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}... src/rulesets/SparseArrays/sparsematrix.jl:154:+)
# Reverse rule for `spdiagm(m, n, kv...)`.
#
# The pullback returns `NoTangent()` for the function itself and for the two integer
# sizes `m` and `n`, followed by one `_diagm_back` result per `offset => vector` pair,
# in the order the pairs were passed.
function rrule(
    ::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...
)
    function spdiagm_pullback(ȳ)
        # Materialise a (possibly thunked) cotangent once; otherwise every pair would
        # trigger its own `unthunk` inside `_diagm_back`.
        ΔΩ = unthunk(ȳ)
        return (NoTangent(), NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ΔΩ))...)
    end
    return spdiagm(m, n, kv...), spdiagm_pullback
end

# Reverse rule for `spdiagm(kv...)` without explicit sizes.
#
# The pullback returns `NoTangent()` for the function itself, followed by one
# `_diagm_back` result per `offset => vector` pair, in the order the pairs were passed.
function rrule(::typeof(spdiagm), kv::Pair{<:Integer,<:AbstractVector}...)
    function spdiagm_pullback(ȳ)
        # Materialise a (possibly thunked) cotangent once; otherwise every pair would
        # trigger its own `unthunk` inside `_diagm_back`.
        ΔΩ = unthunk(ȳ)
        return (NoTangent(), _diagm_back.(kv, Ref(ΔΩ))...)
    end
    return spdiagm(kv...), spdiagm_pullback
end

# Reverse rule for `spdiagm(v)` with a single vector argument.
#
# The pullback returns `NoTangent()` for the function itself and the main diagonal of
# the (unthunked) output cotangent for `v`.
function rrule(::typeof(spdiagm), v::AbstractVector)
    spdiagm_pullback(ȳ) = (NoTangent(), diag(unthunk(ȳ)))
    Ω = spdiagm(v)
    return Ω, spdiagm_pullback
end
47 changes: 46 additions & 1 deletion test/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,51 @@
test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-4)
end

# copied over from test/rulesets/LinearAlgebra/structured
# Checks the pair-based `spdiagm` rrules: the primal must equal a direct `spdiagm`
# call, and each pair's tangent is compared with a finite-difference
# vector-Jacobian product computed by `j′vp` with `_fdm`.
@testset "spdiagm" begin
# Method without explicit sizes: rrule(spdiagm, ps...).
@testset "without size" begin
# NOTE(review): `N` is assigned here but not used anywhere in this testset.
M, N = 7, 9
# Size of the random cotangent below; presumably the size of `spdiagm(ps...)`
# for these inputs — TODO confirm.
s = (8, 8)
a = randn(M)
b = randn(M)
c = randn(M - 1)
ȳ = randn(s)
# `a` and `c` are both placed on diagonal 0; `b` goes on diagonal 1.
ps = (0 => a, 1 => b, 0 => c)
y, back = rrule(spdiagm, ps...)
ElOceanografo marked this conversation as resolved.
Show resolved Hide resolved
@test y == spdiagm(ps...)
# One tangent for the function itself, then one per pair.
∂self, ∂pa, ∂pb, ∂pc = back(ȳ)
@test ∂self === NoTangent()
# Finite-difference reference for the three vectors.
∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(0 => a, 1 => b, 0 => c), ȳ, a, b, c)

Check warning on line 35 in test/rulesets/SparseArrays/sparsematrix.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: test/rulesets/SparseArrays/sparsematrix.jl:35:- ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(0 => a, 1 => b, 0 => c), ȳ, a, b, c) test/rulesets/SparseArrays/sparsematrix.jl:35:+ ∂a_fd, ∂b_fd, ∂c_fd = j′vp( test/rulesets/SparseArrays/sparsematrix.jl:36:+ _fdm, (a, b, c) -> spdiagm(0 => a, 1 => b, 0 => c), ȳ, a, b, c test/rulesets/SparseArrays/sparsematrix.jl:37:+ )
# Each pair tangent must be a `Tangent` for that pair's type, with a zero
# tangent for the offset (`first`) and the finite-difference value for the
# vector (`second`).
for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd))
∂px = unthunk(∂px)
@test ∂px isa Tangent{typeof(p)}
@test ∂px.first isa AbstractZero
@test ∂px.second ≈ ∂x_fd
end
end
# Method with explicit sizes: rrule(spdiagm, M, N, ps...).
@testset "with size" begin
M, N = 7, 9
a = randn(M)
b = randn(M)
c = randn(M - 1)
# Cotangent with the requested M × N shape.
ȳ = randn(M, N)
# `a` and `c` are both placed on diagonal 0; `b` goes on diagonal 1.
ps = (0 => a, 1 => b, 0 => c)
y, back = rrule(spdiagm, M, N, ps...)
@test y == spdiagm(M, N, ps...)
# Tangents for the function, the two sizes, then one per pair.
∂self, ∂M, ∂N, ∂pa, ∂pb, ∂pc = back(ȳ)
@test ∂self === NoTangent()
@test ∂M === NoTangent()
@test ∂N === NoTangent()
# Finite-difference reference for the three vectors.
∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c)

Check warning on line 56 in test/rulesets/SparseArrays/sparsematrix.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: test/rulesets/SparseArrays/sparsematrix.jl:56:- ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c) test/rulesets/SparseArrays/sparsematrix.jl:58:+ ∂a_fd, ∂b_fd, ∂c_fd = j′vp( test/rulesets/SparseArrays/sparsematrix.jl:59:+ _fdm, (a, b, c) -> spdiagm(M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c test/rulesets/SparseArrays/sparsematrix.jl:60:+ )
# Same per-pair checks as in the "without size" testset.
for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd))
∂px = unthunk(∂px)
@test ∂px isa Tangent{typeof(p)}
@test ∂px.first isa AbstractZero
@test ∂px.second ≈ ∂x_fd
end
end
end

@testset "findnz" begin
A = sprand(5, 5, 0.5)
dA = similar(A)
Expand All @@ -42,4 +87,4 @@
test_rrule(logabsdet, A)
test_rrule(logdet, A)
test_rrule(det, A)
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ end

ElOceanografo marked this conversation as resolved.
Show resolved Hide resolved
println()

include_test("rulesets/SparseArrays/sparsematrix.jl")
include("rulesets/SparseArrays/sparsematrix.jl")
oxinabox marked this conversation as resolved.
Show resolved Hide resolved

println()

Expand Down
Loading