From bc2cf6d893ba60f2992faf6f742fab7a7c43eca1 Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Tue, 8 Aug 2023 09:58:20 -0700 Subject: [PATCH 1/8] start working on spdiag rrule --- Project.toml | 1 + src/rulesets/SparseArrays/sparsematrix.jl | 30 +++++++++++ test/rulesets/SparseArrays/sparsematrix.jl | 12 +++++ test/runtests.jl | 60 +++++++++++----------- 4 files changed, 73 insertions(+), 30 deletions(-) diff --git a/Project.toml b/Project.toml index d8f874245..170ad8927 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "1.49.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 51b41421c..4e2967464 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -49,3 +49,33 @@ function rrule(::typeof(findnz), v::AbstractSparseVector) return (I, V), findnz_pullback end + +function _spdiagm_back(p, ȳ) + k, v = p + d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix + return Tangent{typeof(p)}(second = d) +end + +function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...) + function diagm_pullback(Δ) + _, ȳ = unthunk(Δ) + return (NoTangent(), NoTangent(), NoTangent(), _spdiagm_back.(kv, Ref(ȳ))...) + end + return spdiagm(m, n, kv...), diagm_pullback +end + +function rrule(::typeof(spdiagm), kv::Pair{<:Integer,<:AbstractVector}...) + function diagm_pullback(Δ) + _, ȳ = unthunk(Δ) + return (NoTangent(), _spdiagm_back.(kv, Ref(ȳ))...) + end + return spdiagm(kv...), diagm_pullback +end + +function rrule(::typeof(spdiagm), v::AbstractVector) + function diagm_pullback(Δ) + _, ȳ = unthunk(Δ) + return (NoTangent(), diag(ȳ)) + end + return spdiagm(v), diagm_pullback +end diff --git a/test/rulesets/SparseArrays/sparsematrix.jl b/test/rulesets/SparseArrays/sparsematrix.jl index a11a1e963..0582ed159 100644 --- a/test/rulesets/SparseArrays/sparsematrix.jl +++ b/test/rulesets/SparseArrays/sparsematrix.jl @@ -18,6 +18,18 @@ end test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-4) end +@testset "spdiagm" begin + @test 1 == 1 + m = 5 + n = 4 + v1 = ones(m) + v2 = ones(n) + test_rrule(spdiagm, m, n, 0 => v2) + + # test_rrule(spdiagm, 0 => v1) + # test_rrule(spdiagm, v1) +end + @testset "findnz" begin A = sprand(5, 5, 0.5) dA = similar(A) diff --git a/test/runtests.jl b/test/runtests.jl index a9f25c55c..1b9ae0456 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,41 +52,41 @@ end test_method_tables() # Check the global method tables are consistent - # Each file puts all tests inside one or more @testset blocks - include_test("rulesets/Base/base.jl") - include_test("rulesets/Base/fastmath_able.jl") - include_test("rulesets/Base/evalpoly.jl") - include_test("rulesets/Base/array.jl") - include_test("rulesets/Base/arraymath.jl") - include_test("rulesets/Base/indexing.jl") - include_test("rulesets/Base/mapreduce.jl") - include_test("rulesets/Base/sort.jl") - include_test("rulesets/Base/broadcast.jl") - - include_test("unzipped.jl") # used primarily for broadcast + # # Each file puts all tests inside one or more @testset blocks + # include_test("rulesets/Base/base.jl") + # include_test("rulesets/Base/fastmath_able.jl") + # include_test("rulesets/Base/evalpoly.jl") + # include_test("rulesets/Base/array.jl") + # include_test("rulesets/Base/arraymath.jl") + # include_test("rulesets/Base/indexing.jl") + # include_test("rulesets/Base/mapreduce.jl") + # include_test("rulesets/Base/sort.jl") + # include_test("rulesets/Base/broadcast.jl") + + # include_test("unzipped.jl") # used primarily for broadcast + + # println() + + # include_test("rulesets/Statistics/statistics.jl") + + # println() + + # include_test("rulesets/LinearAlgebra/dense.jl") + # include_test("rulesets/LinearAlgebra/norm.jl") + # include_test("rulesets/LinearAlgebra/matfun.jl") + # include_test("rulesets/LinearAlgebra/structured.jl") + # include_test("rulesets/LinearAlgebra/symmetric.jl") + # include_test("rulesets/LinearAlgebra/factorization.jl") + # include_test("rulesets/LinearAlgebra/blas.jl") + # include_test("rulesets/LinearAlgebra/lapack.jl") + # include_test("rulesets/LinearAlgebra/uniformscaling.jl") println() - include_test("rulesets/Statistics/statistics.jl") + include("rulesets/SparseArrays/sparsematrix.jl") println() - include_test("rulesets/LinearAlgebra/dense.jl") - include_test("rulesets/LinearAlgebra/norm.jl") - include_test("rulesets/LinearAlgebra/matfun.jl") - include_test("rulesets/LinearAlgebra/structured.jl") - include_test("rulesets/LinearAlgebra/symmetric.jl") - include_test("rulesets/LinearAlgebra/factorization.jl") - include_test("rulesets/LinearAlgebra/blas.jl") - include_test("rulesets/LinearAlgebra/lapack.jl") - include_test("rulesets/LinearAlgebra/uniformscaling.jl") - - println() - - include_test("rulesets/SparseArrays/sparsematrix.jl") - - println() - - include_test("rulesets/Random/random.jl") + # include_test("rulesets/Random/random.jl") println() end From 3b29953880dd4eda028f8d7c6ee8479cab87c725 Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Fri, 22 Sep 2023 15:53:37 -0700 Subject: [PATCH 2/8] rrule and tests for spdiagm --- src/rulesets/SparseArrays/sparsematrix.jl | 30 ++++++++++++ test/rulesets/SparseArrays/sparsematrix.jl | 53 ++++++++++++++++++---- 2 files changed, 73 insertions(+), 10 deletions(-) diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 8ee8f8cd0..028d27fd5 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -137,3 +137,33 @@ function rrule(::typeof(det), x::SparseMatrixCSC) end return Ω, det_pullback end + + +function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...) + + function spdiagm_pullback(ȳ) + return (NoTangent(), NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ȳ))...) + end + return spdiagm(m, n, kv...), spdiagm_pullback +end + +function rrule(::typeof(spdiagm), kv::Pair{<:Integer,<:AbstractVector}...) + function spdiagm_pullback(ȳ) + return (NoTangent(), _diagm_back.(kv, Ref(ȳ))...) + end + return spdiagm(kv...), spdiagm_pullback +end + +function rrule(::typeof(spdiagm), v::AbstractVector) + function spdiagm_pullback(ȳ) + return (NoTangent(), diag(unthunk(ȳ))) + end + return spdiagm(v), spdiagm_pullback +end + + +function _diagm_back(p, ȳ) + k, v = p + d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix + return Tangent{typeof(p)}(second = d) +end \ No newline at end of file diff --git a/test/rulesets/SparseArrays/sparsematrix.jl b/test/rulesets/SparseArrays/sparsematrix.jl index b5054a3d7..3e1cdb173 100644 --- a/test/rulesets/SparseArrays/sparsematrix.jl +++ b/test/rulesets/SparseArrays/sparsematrix.jl @@ -18,16 +18,49 @@ end test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-4) end +# copied over from test/rulesets/LinearAlgebra/structured @testset "spdiagm" begin - @test 1 == 1 - m = 5 - n = 4 - v1 = ones(m) - v2 = ones(n) - test_rrule(spdiagm, m, n, 0 => v2) - - # test_rrule(spdiagm, 0 => v1) - # test_rrule(spdiagm, v1) + @testset "without size" begin + M, N = 7, 9 + s = (8, 8) + a, ā = randn(M), randn(M) + b, b̄ = randn(M), randn(M) + c, c̄ = randn(M - 1), randn(M - 1) + ȳ = randn(s) + ps = (0 => a, 1 => b, 0 => c) + y, back = rrule(spdiagm, ps...) + @test y == spdiagm(ps...) + ∂self, ∂pa, ∂pb, ∂pc = back(ȳ) + @test ∂self === NoTangent() + ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(0 => a, 1 => b, 0 => c), ȳ, a, b, c) + for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd)) + ∂px = unthunk(∂px) + @test ∂px isa Tangent{typeof(p)} + @test ∂px.first isa AbstractZero + @test ∂px.second ≈ ∂x_fd + end + end + @testset "with size" begin + M, N = 7, 9 + a, ā = randn(M), randn(M) + b, b̄ = randn(M), randn(M) + c, c̄ = randn(M - 1), randn(M - 1) + ȳ = randn(M, N) + ps = (0 => a, 1 => b, 0 => c) + y, back = rrule(spdiagm, M, N, ps...) + @test y == spdiagm(M, N, ps...) + ∂self, ∂M, ∂N, ∂pa, ∂pb, ∂pc = back(ȳ) + @test ∂self === NoTangent() + @test ∂M === NoTangent() + @test ∂N === NoTangent() + ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c) + for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd)) + ∂px = unthunk(∂px) + @test ∂px isa Tangent{typeof(p)} + @test ∂px.first isa AbstractZero + @test ∂px.second ≈ ∂x_fd + end + end end @testset "findnz" begin @@ -54,4 +87,4 @@ end test_rrule(logabsdet, A) test_rrule(logdet, A) test_rrule(det, A) -end \ No newline at end of file +end From ef72d4905d8e457056209e6d66be3e0490180012 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 12 Oct 2023 11:38:47 -0700 Subject: [PATCH 3/8] use eachindex instead of 1:length Co-authored-by: Frames White --- src/rulesets/SparseArrays/sparsematrix.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 028d27fd5..5fb7a3748 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -164,6 +164,6 @@ end function _diagm_back(p, ȳ) k, v = p - d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix + d = diag(unthunk(ȳ), k)[eachindex(v)] # handle if diagonal was smaller than matrix return Tangent{typeof(p)}(second = d) end \ No newline at end of file From 8c3467fe1e7d7d3897df77e0ee226062d02936f1 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 12 Oct 2023 11:40:34 -0700 Subject: [PATCH 4/8] remove unused barred variables Co-authored-by: Frames White --- test/rulesets/SparseArrays/sparsematrix.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/rulesets/SparseArrays/sparsematrix.jl b/test/rulesets/SparseArrays/sparsematrix.jl index 3e1cdb173..283452a8a 100644 --- a/test/rulesets/SparseArrays/sparsematrix.jl +++ b/test/rulesets/SparseArrays/sparsematrix.jl @@ -23,9 +23,9 @@ end @testset "without size" begin M, N = 7, 9 s = (8, 8) - a, ā = randn(M), randn(M) - b, b̄ = randn(M), randn(M) - c, c̄ = randn(M - 1), randn(M - 1) + a = randn(M) + b = randn(M) + c = randn(M - 1) ȳ = randn(s) ps = (0 => a, 1 => b, 0 => c) y, back = rrule(spdiagm, ps...) @@ -42,9 +42,9 @@ end end @testset "with size" begin M, N = 7, 9 - a, ā = randn(M), randn(M) - b, b̄ = randn(M), randn(M) - c, c̄ = randn(M - 1), randn(M - 1) + a = randn(M) + b = randn(M) + c = randn(M - 1) ȳ = randn(M, N) ps = (0 => a, 1 => b, 0 => c) y, back = rrule(spdiagm, M, N, ps...) From b7bd291f85304fa48da703cc3ef0c215041bdc4f Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Thu, 12 Oct 2023 11:49:41 -0700 Subject: [PATCH 5/8] remove duplicate _diagm_back definition --- src/rulesets/LinearAlgebra/structured.jl | 2 +- src/rulesets/SparseArrays/sparsematrix.jl | 7 ------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index da153d14e..245335774 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -96,7 +96,7 @@ end function _diagm_back(p, ȳ) k, v = p - d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix + d = diag(unthunk(ȳ), k)[eachindex(v)] # handle if diagonal was smaller than matrix return Tangent{typeof(p)}(second = d) end diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 5fb7a3748..06de7a135 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -160,10 +160,3 @@ function rrule(::typeof(spdiagm), v::AbstractVector) end return spdiagm(v), spdiagm_pullback end - - -function _diagm_back(p, ȳ) - k, v = p - d = diag(unthunk(ȳ), k)[eachindex(v)] # handle if diagonal was smaller than matrix - return Tangent{typeof(p)}(second = d) -end \ No newline at end of file From d5502884a7da43625b522418ddcf3bdbf3cf8c22 Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Thu, 12 Oct 2023 12:02:07 -0700 Subject: [PATCH 6/8] un-comment other tests --- test/runtests.jl | 58 ++++++++++++++++++++++++------------------------ 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 1b9ae0456..b42a1456a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,34 +52,34 @@ end test_method_tables() # Check the global method tables are consistent - # # Each file puts all tests inside one or more @testset blocks - # include_test("rulesets/Base/base.jl") - # include_test("rulesets/Base/fastmath_able.jl") - # include_test("rulesets/Base/evalpoly.jl") - # include_test("rulesets/Base/array.jl") - # include_test("rulesets/Base/arraymath.jl") - # include_test("rulesets/Base/indexing.jl") - # include_test("rulesets/Base/mapreduce.jl") - # include_test("rulesets/Base/sort.jl") - # include_test("rulesets/Base/broadcast.jl") - - # include_test("unzipped.jl") # used primarily for broadcast - - # println() - - # include_test("rulesets/Statistics/statistics.jl") - - # println() - - # include_test("rulesets/LinearAlgebra/dense.jl") - # include_test("rulesets/LinearAlgebra/norm.jl") - # include_test("rulesets/LinearAlgebra/matfun.jl") - # include_test("rulesets/LinearAlgebra/structured.jl") - # include_test("rulesets/LinearAlgebra/symmetric.jl") - # include_test("rulesets/LinearAlgebra/factorization.jl") - # include_test("rulesets/LinearAlgebra/blas.jl") - # include_test("rulesets/LinearAlgebra/lapack.jl") - # include_test("rulesets/LinearAlgebra/uniformscaling.jl") + # Each file puts all tests inside one or more @testset blocks + include_test("rulesets/Base/base.jl") + include_test("rulesets/Base/fastmath_able.jl") + include_test("rulesets/Base/evalpoly.jl") + include_test("rulesets/Base/array.jl") + include_test("rulesets/Base/arraymath.jl") + include_test("rulesets/Base/indexing.jl") + include_test("rulesets/Base/mapreduce.jl") + include_test("rulesets/Base/sort.jl") + include_test("rulesets/Base/broadcast.jl") + + include_test("unzipped.jl") # used primarily for broadcast + + println() + + include_test("rulesets/Statistics/statistics.jl") + + println() + + include_test("rulesets/LinearAlgebra/dense.jl") + include_test("rulesets/LinearAlgebra/norm.jl") + include_test("rulesets/LinearAlgebra/matfun.jl") + include_test("rulesets/LinearAlgebra/structured.jl") + include_test("rulesets/LinearAlgebra/symmetric.jl") + include_test("rulesets/LinearAlgebra/factorization.jl") + include_test("rulesets/LinearAlgebra/blas.jl") + include_test("rulesets/LinearAlgebra/lapack.jl") + include_test("rulesets/LinearAlgebra/uniformscaling.jl") println() @@ -87,6 +87,6 @@ end println() - # include_test("rulesets/Random/random.jl") + include_test("rulesets/Random/random.jl") println() end From 0be6e06c7e66a64b85da447f0cf716ad58a11a28 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 23 Oct 2023 23:51:51 +0800 Subject: [PATCH 7/8] Delete direct dep on CRTU --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7daaf90f5..8e0a07cb1 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,6 @@ version = "1.55.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" From 7a6d64870df49303d1b5a26718fda00abab8475b Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 23 Oct 2023 23:53:03 +0800 Subject: [PATCH 8/8] Using include_test to include a test --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index b42a1456a..a9f25c55c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -83,7 +83,7 @@ end println() - include("rulesets/SparseArrays/sparsematrix.jl") + include_test("rulesets/SparseArrays/sparsematrix.jl") println()