From bc2cf6d893ba60f2992faf6f742fab7a7c43eca1 Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Tue, 8 Aug 2023 09:58:20 -0700 Subject: [PATCH 01/31] start working on spdiag rrule --- Project.toml | 1 + src/rulesets/SparseArrays/sparsematrix.jl | 30 +++++++++++ test/rulesets/SparseArrays/sparsematrix.jl | 12 +++++ test/runtests.jl | 60 +++++++++++----------- 4 files changed, 73 insertions(+), 30 deletions(-) diff --git a/Project.toml b/Project.toml index d8f874245..170ad8927 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "1.49.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 51b41421c..4e2967464 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -49,3 +49,33 @@ function rrule(::typeof(findnz), v::AbstractSparseVector) return (I, V), findnz_pullback end + +function _spdiagm_back(p, ȳ) + k, v = p + d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix + return Tangent{typeof(p)}(second = d) +end + +function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...) + function diagm_pullback(Δ) + _, ȳ = unthunk(Δ) + return (NoTangent(), NoTangent(), NoTangent(), _spdiagm_back.(kv, Ref(ȳ))...) + end + return spdiagm(m, n, kv...), diagm_pullback +end + +function rrule(::typeof(spdiagm), kv::Pair{<:Integer,<:AbstractVector}...) + function diagm_pullback(Δ) + _, ȳ = unthunk(Δ) + return (NoTangent(), _spdiagm_back.(kv, Ref(ȳ))...) + end + return spdiagm(kv...), diagm_pullback +end + +function rrule(::typeof(spdiagm), v::AbstractVector) + function diagm_pullback(Δ) + _, ȳ = unthunk(Δ) + return (NoTangent(), diag(ȳ)) + end + return spdiagm(v), diagm_pullback +end diff --git a/test/rulesets/SparseArrays/sparsematrix.jl b/test/rulesets/SparseArrays/sparsematrix.jl index a11a1e963..0582ed159 100644 --- a/test/rulesets/SparseArrays/sparsematrix.jl +++ b/test/rulesets/SparseArrays/sparsematrix.jl @@ -18,6 +18,18 @@ end test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-4) end +@testset "spdiagm" begin + @test 1 == 1 + m = 5 + n = 4 + v1 = ones(m) + v2 = ones(n) + test_rrule(spdiagm, m, n, 0 => v2) + + # test_rrule(spdiagm, 0 => v1) + # test_rrule(spdiagm, v1) +end + @testset "findnz" begin A = sprand(5, 5, 0.5) dA = similar(A) diff --git a/test/runtests.jl b/test/runtests.jl index a9f25c55c..1b9ae0456 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,41 +52,41 @@ end test_method_tables() # Check the global method tables are consistent - # Each file puts all tests inside one or more @testset blocks - include_test("rulesets/Base/base.jl") - include_test("rulesets/Base/fastmath_able.jl") - include_test("rulesets/Base/evalpoly.jl") - include_test("rulesets/Base/array.jl") - include_test("rulesets/Base/arraymath.jl") - include_test("rulesets/Base/indexing.jl") - include_test("rulesets/Base/mapreduce.jl") - include_test("rulesets/Base/sort.jl") - include_test("rulesets/Base/broadcast.jl") - - include_test("unzipped.jl") # used primarily for broadcast + # # Each file puts all tests inside one or more @testset blocks + # include_test("rulesets/Base/base.jl") + # include_test("rulesets/Base/fastmath_able.jl") + # include_test("rulesets/Base/evalpoly.jl") + # include_test("rulesets/Base/array.jl") + # include_test("rulesets/Base/arraymath.jl") + # include_test("rulesets/Base/indexing.jl") + # include_test("rulesets/Base/mapreduce.jl") + # include_test("rulesets/Base/sort.jl") + # include_test("rulesets/Base/broadcast.jl") + + # include_test("unzipped.jl") # used primarily for broadcast + + # println() + + # include_test("rulesets/Statistics/statistics.jl") + + # println() + + # include_test("rulesets/LinearAlgebra/dense.jl") + # include_test("rulesets/LinearAlgebra/norm.jl") + # include_test("rulesets/LinearAlgebra/matfun.jl") + # include_test("rulesets/LinearAlgebra/structured.jl") + # include_test("rulesets/LinearAlgebra/symmetric.jl") + # include_test("rulesets/LinearAlgebra/factorization.jl") + # include_test("rulesets/LinearAlgebra/blas.jl") + # include_test("rulesets/LinearAlgebra/lapack.jl") + # include_test("rulesets/LinearAlgebra/uniformscaling.jl") println() - include_test("rulesets/Statistics/statistics.jl") + include("rulesets/SparseArrays/sparsematrix.jl") println() - include_test("rulesets/LinearAlgebra/dense.jl") - include_test("rulesets/LinearAlgebra/norm.jl") - include_test("rulesets/LinearAlgebra/matfun.jl") - include_test("rulesets/LinearAlgebra/structured.jl") - include_test("rulesets/LinearAlgebra/symmetric.jl") - include_test("rulesets/LinearAlgebra/factorization.jl") - include_test("rulesets/LinearAlgebra/blas.jl") - include_test("rulesets/LinearAlgebra/lapack.jl") - include_test("rulesets/LinearAlgebra/uniformscaling.jl") - - println() - - include_test("rulesets/SparseArrays/sparsematrix.jl") - - println() - - include_test("rulesets/Random/random.jl") + # include_test("rulesets/Random/random.jl") println() end From 3b29953880dd4eda028f8d7c6ee8479cab87c725 Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Fri, 22 Sep 2023 15:53:37 -0700 Subject: [PATCH 02/31] rrule and tests for spdiagm --- src/rulesets/SparseArrays/sparsematrix.jl | 30 ++++++++++++ test/rulesets/SparseArrays/sparsematrix.jl | 53 ++++++++++++++++++---- 2 files changed, 73 insertions(+), 10 deletions(-) diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 8ee8f8cd0..028d27fd5 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -137,3 +137,33 @@ function rrule(::typeof(det), x::SparseMatrixCSC) end return Ω, det_pullback end + + +function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...) + + function spdiagm_pullback(ȳ) + return (NoTangent(), NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ȳ))...) + end + return spdiagm(m, n, kv...), spdiagm_pullback +end + +function rrule(::typeof(spdiagm), kv::Pair{<:Integer,<:AbstractVector}...) + function spdiagm_pullback(ȳ) + return (NoTangent(), _diagm_back.(kv, Ref(ȳ))...) + end + return spdiagm(kv...), spdiagm_pullback +end + +function rrule(::typeof(spdiagm), v::AbstractVector) + function spdiagm_pullback(ȳ) + return (NoTangent(), diag(unthunk(ȳ))) + end + return spdiagm(v), spdiagm_pullback +end + + +function _diagm_back(p, ȳ) + k, v = p + d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix + return Tangent{typeof(p)}(second = d) +end \ No newline at end of file diff --git a/test/rulesets/SparseArrays/sparsematrix.jl b/test/rulesets/SparseArrays/sparsematrix.jl index b5054a3d7..3e1cdb173 100644 --- a/test/rulesets/SparseArrays/sparsematrix.jl +++ b/test/rulesets/SparseArrays/sparsematrix.jl @@ -18,16 +18,49 @@ end test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-4) end +# copied over from test/rulesets/LinearAlgebra/structured @testset "spdiagm" begin - @test 1 == 1 - m = 5 - n = 4 - v1 = ones(m) - v2 = ones(n) - test_rrule(spdiagm, m, n, 0 => v2) - - # test_rrule(spdiagm, 0 => v1) - # test_rrule(spdiagm, v1) + @testset "without size" begin + M, N = 7, 9 + s = (8, 8) + a, ā = randn(M), randn(M) + b, b̄ = randn(M), randn(M) + c, c̄ = randn(M - 1), randn(M - 1) + ȳ = randn(s) + ps = (0 => a, 1 => b, 0 => c) + y, back = rrule(spdiagm, ps...) + @test y == spdiagm(ps...) + ∂self, ∂pa, ∂pb, ∂pc = back(ȳ) + @test ∂self === NoTangent() + ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(0 => a, 1 => b, 0 => c), ȳ, a, b, c) + for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd)) + ∂px = unthunk(∂px) + @test ∂px isa Tangent{typeof(p)} + @test ∂px.first isa AbstractZero + @test ∂px.second ≈ ∂x_fd + end + end + @testset "with size" begin + M, N = 7, 9 + a, ā = randn(M), randn(M) + b, b̄ = randn(M), randn(M) + c, c̄ = randn(M - 1), randn(M - 1) + ȳ = randn(M, N) + ps = (0 => a, 1 => b, 0 => c) + y, back = rrule(spdiagm, M, N, ps...) + @test y == spdiagm(M, N, ps...) + ∂self, ∂M, ∂N, ∂pa, ∂pb, ∂pc = back(ȳ) + @test ∂self === NoTangent() + @test ∂M === NoTangent() + @test ∂N === NoTangent() + ∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c) + for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd)) + ∂px = unthunk(∂px) + @test ∂px isa Tangent{typeof(p)} + @test ∂px.first isa AbstractZero + @test ∂px.second ≈ ∂x_fd + end + end end @testset "findnz" begin @@ -54,4 +87,4 @@ end test_rrule(logabsdet, A) test_rrule(logdet, A) test_rrule(det, A) -end \ No newline at end of file +end From 15a46e2cbefaf2f8f6b07caaf59680e3fdb895f7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 25 Sep 2023 20:22:57 +0000 Subject: [PATCH 03/31] Bump actions/checkout from 4.0.0 to 4.1.0 Bumps [actions/checkout](https://github.com/actions/checkout) from 4.0.0 to 4.1.0. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v4.0.0...v4.1.0) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/CI.yml | 2 +- .github/workflows/IntegrationTest.yml | 4 ++-- .github/workflows/JuliaNightly.yml | 2 +- .github/workflows/VersionVigilante_pull_request.yml | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 5e31ddcdb..6549360ca 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -21,7 +21,7 @@ jobs: - x86 - x64 steps: - - uses: actions/checkout@v4.0.0 + - uses: actions/checkout@v4.1.0 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index f49c7cddf..a03183a73 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -25,14 +25,14 @@ jobs: # package: {user: JuliaDiff, repo: Diffractor.jl} steps: - - uses: actions/checkout@v4.0.0 + - uses: actions/checkout@v4.1.0 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.julia-version }} arch: x64 - uses: julia-actions/julia-buildpkg@latest - name: Clone Downstream - uses: actions/checkout@v4.0.0 + uses: actions/checkout@v4.1.0 with: repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} path: downstream diff --git a/.github/workflows/JuliaNightly.yml b/.github/workflows/JuliaNightly.yml index 0d3526c0b..16b28983a 100644 --- a/.github/workflows/JuliaNightly.yml +++ b/.github/workflows/JuliaNightly.yml @@ -23,7 +23,7 @@ jobs: - x86 - x64 steps: - - uses: actions/checkout@v4.0.0 + - uses: actions/checkout@v4.1.0 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} diff --git a/.github/workflows/VersionVigilante_pull_request.yml b/.github/workflows/VersionVigilante_pull_request.yml index 76fffeac4..da513155e 100644 --- a/.github/workflows/VersionVigilante_pull_request.yml +++ b/.github/workflows/VersionVigilante_pull_request.yml @@ -6,7 +6,7 @@ jobs: VersionVigilante: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4.0.0 + - uses: actions/checkout@v4.1.0 - uses: julia-actions/setup-julia@latest - name: VersionVigilante.main id: versionvigilante_main From bea1b810a608e2984b55b26bfcd2587e68ea3793 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 21 Sep 2023 16:04:13 +0800 Subject: [PATCH 04/31] add frules for getfield --- src/rulesets/Base/indexing.jl | 9 ++++++--- test/rulesets/Base/indexing.jl | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 1334cc925..37ed8ca48 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -1,6 +1,10 @@ # Int rather than Int64/Integer is intentional -function frule((_, ẋ), ::typeof(getfield), x::Tuple, i::Int) - return x.i, ẋ.i +function ChainRulesCore.frule((_, Δ, _), ::typeof(getfield), strct, sym::Union{Int,Symbol}) + return (getfield(strct, sym), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym)) +end + +function ChainRulesCore.frule((_, Δ, _, _), ::typeof(getfield), strct, sym::Union{Int,Symbol}, inbounds) + return (getfield(strct, sym, inbounds), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym)) end "for a given tuple type, returns a Val{N} where N is the length of the tuple" @@ -21,7 +25,6 @@ function rrule(::typeof(getindex), x::T, i::Integer) where {T<:NTuple{<:Any,<:Nu dx = ntuple(j -> j == i ? dy : zero(dy), _tuple_N(T)) return (NoTangent(), Tangent{T}(dx...), NoTangent()) end - return x[i], getindex_back_2 end # Note Zygote has getindex(::Tuple, ::UnitRange) separately from getindex(::Tuple, ::AbstractVector), diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index d3c7ecfb4..c21bb8425 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -1,3 +1,17 @@ +@testset "getfield" begin + struct Foo + x::Float64 + y::Float64 + end + test_frule(getfield, Foo(1.5, 2.5), :x, check_inferred=false) + + test_frule(getfield, (; a=1.5, b=2.5), :a, check_inferred=false) + test_frule(getfield, (; a=1.5, b=2.5), 2) + + test_frule(getfield, (1.5, 2.5), 2) + test_frule(getfield, (1.5, 2.5), 2, true) +end + @testset "getindex" begin @testset "getindex(::Tuple, ...)" begin x = (1.2, 3.4, 5.6) From 84cd7be44fb0493aeb18168bf6aca59fc5885d65 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 26 Sep 2023 22:14:11 +0800 Subject: [PATCH 05/31] move struct to top level --- test/rulesets/Base/indexing.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index c21bb8425..a677df3b9 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -1,9 +1,11 @@ +struct FooTwoField + x::Float64 + y::Float64 +end + + @testset "getfield" begin - struct Foo - x::Float64 - y::Float64 - end - test_frule(getfield, Foo(1.5, 2.5), :x, check_inferred=false) + test_frule(getfield, FooTwoField(1.5, 2.5), :x, check_inferred=false) test_frule(getfield, (; a=1.5, b=2.5), :a, check_inferred=false) test_frule(getfield, (; a=1.5, b=2.5), 2) From a76e0ae4a536e1e35d3024d50a940cf7c8a78df2 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 28 Sep 2023 14:49:14 +0800 Subject: [PATCH 06/31] undo mistakenly deleted line --- src/rulesets/Base/indexing.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 37ed8ca48..2f5e6cf79 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -25,6 +25,7 @@ function rrule(::typeof(getindex), x::T, i::Integer) where {T<:NTuple{<:Any,<:Nu dx = ntuple(j -> j == i ? dy : zero(dy), _tuple_N(T)) return (NoTangent(), Tangent{T}(dx...), NoTangent()) end + return x[i], getindex_back_2 end # Note Zygote has getindex(::Tuple, ::UnitRange) separately from getindex(::Tuple, ::AbstractVector), From 5316136f1cb1711108c5f10c21b63b6eab1560a1 Mon Sep 17 00:00:00 2001 From: hyrodium Date: Sun, 1 Oct 2023 20:30:32 +0900 Subject: [PATCH 07/31] copy JuliaFormatter config from ChainRulesCore.jl --- .JuliaFormatter.toml | 1 + .github/workflows/format.yml | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+) create mode 100644 .JuliaFormatter.toml create mode 100644 .github/workflows/format.yml diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 000000000..323237bab --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style = "blue" diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml new file mode 100644 index 000000000..f6f268c0e --- /dev/null +++ b/.github/workflows/format.yml @@ -0,0 +1,27 @@ +name: Format suggestions + +on: + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@latest + with: + version: 1 + - run: | + julia -e 'using Pkg; Pkg.add("JuliaFormatter")' + julia -e 'using JuliaFormatter; format("."; verbose=true)' + - uses: reviewdog/action-suggester@v1 + with: + tool_name: JuliaFormatter + fail_on_error: true + filter_mode: added From e2781569b670ed53c3884b7bb6c14e8bc23a6d47 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Oct 2023 20:49:30 +0000 Subject: [PATCH 08/31] Bump actions/checkout from 2 to 4 Bumps [actions/checkout](https://github.com/actions/checkout) from 2 to 4. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v2...v4) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/CI.yml | 2 +- .github/workflows/IntegrationTest.yml | 4 ++-- .github/workflows/JuliaNightly.yml | 2 +- .github/workflows/VersionVigilante_pull_request.yml | 2 +- .github/workflows/format.yml | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 6549360ca..dc73a1d7a 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -21,7 +21,7 @@ jobs: - x86 - x64 steps: - - uses: actions/checkout@v4.1.0 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index a03183a73..c63b657c1 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -25,14 +25,14 @@ jobs: # package: {user: JuliaDiff, repo: Diffractor.jl} steps: - - uses: actions/checkout@v4.1.0 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.julia-version }} arch: x64 - uses: julia-actions/julia-buildpkg@latest - name: Clone Downstream - uses: actions/checkout@v4.1.0 + uses: actions/checkout@v4 with: repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} path: downstream diff --git a/.github/workflows/JuliaNightly.yml b/.github/workflows/JuliaNightly.yml index 16b28983a..a1281c5f6 100644 --- a/.github/workflows/JuliaNightly.yml +++ b/.github/workflows/JuliaNightly.yml @@ -23,7 +23,7 @@ jobs: - x86 - x64 steps: - - uses: actions/checkout@v4.1.0 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} diff --git a/.github/workflows/VersionVigilante_pull_request.yml b/.github/workflows/VersionVigilante_pull_request.yml index da513155e..57ee668e3 100644 --- a/.github/workflows/VersionVigilante_pull_request.yml +++ b/.github/workflows/VersionVigilante_pull_request.yml @@ -6,7 +6,7 @@ jobs: VersionVigilante: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4.1.0 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@latest - name: VersionVigilante.main id: versionvigilante_main diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index f6f268c0e..f80377a24 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -13,7 +13,7 @@ jobs: format: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@latest with: version: 1 From 8e80b16f4cd749efb8773e595abd29d371ee7113 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 3 Oct 2023 18:36:25 +0800 Subject: [PATCH 09/31] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8dfefc100..8e0a07cb1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.54.0" +version = "1.55.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 4fd2e5849cf3ec58070eae7453aeec7fd4f000f1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Oct 2023 20:43:42 +0000 Subject: [PATCH 10/31] Bump styfle/cancel-workflow-action from 0.9.0 to 0.12.0 Bumps [styfle/cancel-workflow-action](https://github.com/styfle/cancel-workflow-action) from 0.9.0 to 0.12.0. - [Release notes](https://github.com/styfle/cancel-workflow-action/releases) - [Commits](https://github.com/styfle/cancel-workflow-action/compare/0.9.0...0.12.0) --- updated-dependencies: - dependency-name: styfle/cancel-workflow-action dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/Cancel.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/Cancel.yml b/.github/workflows/Cancel.yml index 85b1ef3d2..652e014a9 100644 --- a/.github/workflows/Cancel.yml +++ b/.github/workflows/Cancel.yml @@ -13,7 +13,7 @@ jobs: cancel: runs-on: ubuntu-latest steps: - - uses: styfle/cancel-workflow-action@0.9.0 + - uses: styfle/cancel-workflow-action@0.12.0 with: all_but_latest: true workflow_id: ${{ github.event.workflow.id }} From ef72d4905d8e457056209e6d66be3e0490180012 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 12 Oct 2023 11:38:47 -0700 Subject: [PATCH 11/31] use eachindex instead of 1:length Co-authored-by: Frames White --- src/rulesets/SparseArrays/sparsematrix.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 028d27fd5..5fb7a3748 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -164,6 +164,6 @@ end function _diagm_back(p, ȳ) k, v = p - d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix + d = diag(unthunk(ȳ), k)[eachindex(v)] # handle if diagonal was smaller than matrix return Tangent{typeof(p)}(second = d) end \ No newline at end of file From 8c3467fe1e7d7d3897df77e0ee226062d02936f1 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 12 Oct 2023 11:40:34 -0700 Subject: [PATCH 12/31] remove unused barred variables Co-authored-by: Frames White --- test/rulesets/SparseArrays/sparsematrix.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/rulesets/SparseArrays/sparsematrix.jl b/test/rulesets/SparseArrays/sparsematrix.jl index 3e1cdb173..283452a8a 100644 --- a/test/rulesets/SparseArrays/sparsematrix.jl +++ b/test/rulesets/SparseArrays/sparsematrix.jl @@ -23,9 +23,9 @@ end @testset "without size" begin M, N = 7, 9 s = (8, 8) - a, ā = randn(M), randn(M) - b, b̄ = randn(M), randn(M) - c, c̄ = randn(M - 1), randn(M - 1) + a = randn(M) + b = randn(M) + c = randn(M - 1) ȳ = randn(s) ps = (0 => a, 1 => b, 0 => c) y, back = rrule(spdiagm, ps...) @@ -42,9 +42,9 @@ end end @testset "with size" begin M, N = 7, 9 - a, ā = randn(M), randn(M) - b, b̄ = randn(M), randn(M) - c, c̄ = randn(M - 1), randn(M - 1) + a = randn(M) + b = randn(M) + c = randn(M - 1) ȳ = randn(M, N) ps = (0 => a, 1 => b, 0 => c) y, back = rrule(spdiagm, M, N, ps...) From b7bd291f85304fa48da703cc3ef0c215041bdc4f Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Thu, 12 Oct 2023 11:49:41 -0700 Subject: [PATCH 13/31] remove duplicate _diagm_back definition --- src/rulesets/LinearAlgebra/structured.jl | 2 +- src/rulesets/SparseArrays/sparsematrix.jl | 7 ------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index da153d14e..245335774 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -96,7 +96,7 @@ end function _diagm_back(p, ȳ) k, v = p - d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix + d = diag(unthunk(ȳ), k)[eachindex(v)] # handle if diagonal was smaller than matrix return Tangent{typeof(p)}(second = d) end diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 5fb7a3748..06de7a135 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -160,10 +160,3 @@ function rrule(::typeof(spdiagm), v::AbstractVector) end return spdiagm(v), spdiagm_pullback end - - -function _diagm_back(p, ȳ) - k, v = p - d = diag(unthunk(ȳ), k)[eachindex(v)] # handle if diagonal was smaller than matrix - return Tangent{typeof(p)}(second = d) -end \ No newline at end of file From d5502884a7da43625b522418ddcf3bdbf3cf8c22 Mon Sep 17 00:00:00 2001 From: Sam Urmy Date: Thu, 12 Oct 2023 12:02:07 -0700 Subject: [PATCH 14/31] un-comment other tests --- test/runtests.jl | 58 ++++++++++++++++++++++++------------------------ 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 1b9ae0456..b42a1456a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,34 +52,34 @@ end test_method_tables() # Check the global method tables are consistent - # # Each file puts all tests inside one or more @testset blocks - # include_test("rulesets/Base/base.jl") - # include_test("rulesets/Base/fastmath_able.jl") - # include_test("rulesets/Base/evalpoly.jl") - # include_test("rulesets/Base/array.jl") - # include_test("rulesets/Base/arraymath.jl") - # include_test("rulesets/Base/indexing.jl") - # include_test("rulesets/Base/mapreduce.jl") - # include_test("rulesets/Base/sort.jl") - # include_test("rulesets/Base/broadcast.jl") - - # include_test("unzipped.jl") # used primarily for broadcast - - # println() - - # include_test("rulesets/Statistics/statistics.jl") - - # println() - - # include_test("rulesets/LinearAlgebra/dense.jl") - # include_test("rulesets/LinearAlgebra/norm.jl") - # include_test("rulesets/LinearAlgebra/matfun.jl") - # include_test("rulesets/LinearAlgebra/structured.jl") - # include_test("rulesets/LinearAlgebra/symmetric.jl") - # include_test("rulesets/LinearAlgebra/factorization.jl") - # include_test("rulesets/LinearAlgebra/blas.jl") - # include_test("rulesets/LinearAlgebra/lapack.jl") - # include_test("rulesets/LinearAlgebra/uniformscaling.jl") + # Each file puts all tests inside one or more @testset blocks + include_test("rulesets/Base/base.jl") + include_test("rulesets/Base/fastmath_able.jl") + include_test("rulesets/Base/evalpoly.jl") + include_test("rulesets/Base/array.jl") + include_test("rulesets/Base/arraymath.jl") + include_test("rulesets/Base/indexing.jl") + include_test("rulesets/Base/mapreduce.jl") + include_test("rulesets/Base/sort.jl") + include_test("rulesets/Base/broadcast.jl") + + include_test("unzipped.jl") # used primarily for broadcast + + println() + + include_test("rulesets/Statistics/statistics.jl") + + println() + + include_test("rulesets/LinearAlgebra/dense.jl") + include_test("rulesets/LinearAlgebra/norm.jl") + include_test("rulesets/LinearAlgebra/matfun.jl") + include_test("rulesets/LinearAlgebra/structured.jl") + include_test("rulesets/LinearAlgebra/symmetric.jl") + include_test("rulesets/LinearAlgebra/factorization.jl") + include_test("rulesets/LinearAlgebra/blas.jl") + include_test("rulesets/LinearAlgebra/lapack.jl") + include_test("rulesets/LinearAlgebra/uniformscaling.jl") println() @@ -87,6 +87,6 @@ end println() - # include_test("rulesets/Random/random.jl") + include_test("rulesets/Random/random.jl") println() end From 0be6e06c7e66a64b85da447f0cf716ad58a11a28 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 23 Oct 2023 23:51:51 +0800 Subject: [PATCH 15/31] Delete direct dep on CRTU --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7daaf90f5..8e0a07cb1 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,6 @@ version = "1.55.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" From 7a6d64870df49303d1b5a26718fda00abab8475b Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 23 Oct 2023 23:53:03 +0800 Subject: [PATCH 16/31] Using include_test to include a test --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index b42a1456a..a9f25c55c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -83,7 +83,7 @@ end println() - include("rulesets/SparseArrays/sparsematrix.jl") + include_test("rulesets/SparseArrays/sparsematrix.jl") println() From b9046e73149809747129f6899526309f741a3886 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 24 Oct 2023 11:23:14 +0800 Subject: [PATCH 17/31] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8e0a07cb1..b3d509c2b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.55.0" +version = "1.56.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From b8adca6ef3526ce6b562b8f9b2c022e7581f54f7 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sun, 29 Oct 2023 10:37:12 -0400 Subject: [PATCH 18/31] CompatHelper: add new compat entry for Statistics at version 1, (keep existing compat) (#748) --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index b3d509c2b..7f3b0ba04 100644 --- a/Project.toml +++ b/Project.toml @@ -31,6 +31,7 @@ JuliaInterpreter = "0.8,0.9" RealDot = "0.1" SparseInverseSubset = "0.1" StaticArrays = "1.2" +Statistics = "1" StructArrays = "0.6.11" julia = "1.6" From 6d8616d9924da957f2c1cb98b61f0d87d15110c6 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 3 Nov 2023 16:14:35 +0800 Subject: [PATCH 19/31] Add rule for with_logger --- src/ChainRules.jl | 1 + src/rulesets/Base/CoreLogging.jl | 20 ++++++++++++++++++++ src/rulesets/Base/nondiff.jl | 4 ---- test/rulesets/Base/CoreLogging.jl | 11 +++++++++++ test/runtests.jl | 1 + 5 files changed, 33 insertions(+), 4 deletions(-) create mode 100644 src/rulesets/Base/CoreLogging.jl create mode 100644 test/rulesets/Base/CoreLogging.jl diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 28e73c166..6d33a22e7 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -43,6 +43,7 @@ include("rulesets/Base/indexing.jl") include("rulesets/Base/sort.jl") include("rulesets/Base/mapreduce.jl") include("rulesets/Base/broadcast.jl") +include("rulesets/Base/CoreLogging.jl") include("rulesets/Distributed/nondiff.jl") diff --git a/src/rulesets/Base/CoreLogging.jl b/src/rulesets/Base/CoreLogging.jl new file mode 100644 index 000000000..fafcb8314 --- /dev/null +++ b/src/rulesets/Base/CoreLogging.jl @@ -0,0 +1,20 @@ +# For the CoreLogging submodule of Base. (not to be confused with the Logging stdlib) + +function rrule( + rc::RuleConfig{>:ChainRulesCore.HasReverseMode}, + ::typeof(Base.CoreLogging.with_logger), + f::Function, + logger::Base.CoreLogging.AbstractLogger +) + y, f_pb = Base.CoreLogging.with_logger(logger) do + rrule_via_ad(rc, f) + end + with_logger_pullback(ȳ) = (NoTangent(), only(f_pb(ȳ)), NoTangent()) + return y, with_logger_pullback +end + +@non_differentiable Base.CoreLogging.current_logger(args...) +@non_differentiable Base.CoreLogging.current_logger_for_env(::Any...) +@non_differentiable Base.CoreLogging._invoked_shouldlog(::Any...) +@non_differentiable Base.CoreLogging.Base.fixup_stdlib_path(::Any) +@non_differentiable Base.CoreLogging.handle_message(::Any...) \ No newline at end of file diff --git a/src/rulesets/Base/nondiff.jl b/src/rulesets/Base/nondiff.jl index 58298f068..d35024163 100644 --- a/src/rulesets/Base/nondiff.jl +++ b/src/rulesets/Base/nondiff.jl @@ -483,10 +483,6 @@ end @non_differentiable Broadcast.result_style(::Any) @non_differentiable Broadcast.result_style(::Any, ::Any) -@non_differentiable Base.CoreLogging.current_logger_for_env(::Any...) -@non_differentiable Base.CoreLogging._invoked_shouldlog(::Any...) -@non_differentiable Base.CoreLogging.Base.fixup_stdlib_path(::Any) -@non_differentiable Base.CoreLogging.handle_message(::Any...) @non_differentiable Libc.free(::Any) @non_differentiable Libc.getpid() diff --git a/test/rulesets/Base/CoreLogging.jl b/test/rulesets/Base/CoreLogging.jl new file mode 100644 index 000000000..d54806587 --- /dev/null +++ b/test/rulesets/Base/CoreLogging.jl @@ -0,0 +1,11 @@ +# For the CoreLogging submodule of Base. (not to be confused with the Logging stdlib) +@testset "CoreLogging.jl" begin + @testset "with_logger" begin + test_rrule( + Base.CoreLogging.with_logger, + ()->2.0 * 3.0, + Base.CoreLogging.NullLogger(); + check_inferred=false + ) + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index a9f25c55c..367a63de1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,6 +53,7 @@ end test_method_tables() # Check the global method tables are consistent # Each file puts all tests inside one or more @testset blocks + include_test("rulesets/CoreLogging.jl") include_test("rulesets/Base/base.jl") include_test("rulesets/Base/fastmath_able.jl") include_test("rulesets/Base/evalpoly.jl") From 8613c5ec633ac379e1bc7f0979678d4be499b4a2 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 3 Nov 2023 16:17:10 +0800 Subject: [PATCH 20/31] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7f3b0ba04..1c7ebb77e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.56.0" +version = "1.57.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 613d9d23cfac87eb5802283da77c9771e476c8ab Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 3 Nov 2023 16:26:24 +0800 Subject: [PATCH 21/31] style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/rulesets/Base/CoreLogging.jl | 2 +- test/rulesets/Base/CoreLogging.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/CoreLogging.jl b/src/rulesets/Base/CoreLogging.jl index fafcb8314..ac05ed4b8 100644 --- a/src/rulesets/Base/CoreLogging.jl +++ b/src/rulesets/Base/CoreLogging.jl @@ -4,7 +4,7 @@ function rrule( rc::RuleConfig{>:ChainRulesCore.HasReverseMode}, ::typeof(Base.CoreLogging.with_logger), f::Function, - logger::Base.CoreLogging.AbstractLogger + logger::Base.CoreLogging.AbstractLogger, ) y, f_pb = Base.CoreLogging.with_logger(logger) do rrule_via_ad(rc, f) diff --git a/test/rulesets/Base/CoreLogging.jl b/test/rulesets/Base/CoreLogging.jl index d54806587..7528df078 100644 --- a/test/rulesets/Base/CoreLogging.jl +++ b/test/rulesets/Base/CoreLogging.jl @@ -3,9 +3,9 @@ @testset "with_logger" begin test_rrule( Base.CoreLogging.with_logger, - ()->2.0 * 3.0, + () -> 2.0 * 3.0, Base.CoreLogging.NullLogger(); - check_inferred=false + check_inferred=false, ) end end \ No newline at end of file From 7e2ea02881256650dae2f9d5482e0b106675faa1 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 3 Nov 2023 16:28:20 +0800 Subject: [PATCH 22/31] new lines on end of files --- src/rulesets/Base/CoreLogging.jl | 2 +- test/rulesets/Base/CoreLogging.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/CoreLogging.jl b/src/rulesets/Base/CoreLogging.jl index ac05ed4b8..ae97f4e40 100644 --- a/src/rulesets/Base/CoreLogging.jl +++ b/src/rulesets/Base/CoreLogging.jl @@ -17,4 +17,4 @@ end @non_differentiable Base.CoreLogging.current_logger_for_env(::Any...) @non_differentiable Base.CoreLogging._invoked_shouldlog(::Any...) @non_differentiable Base.CoreLogging.Base.fixup_stdlib_path(::Any) -@non_differentiable Base.CoreLogging.handle_message(::Any...) \ No newline at end of file +@non_differentiable Base.CoreLogging.handle_message(::Any...) diff --git a/test/rulesets/Base/CoreLogging.jl b/test/rulesets/Base/CoreLogging.jl index 7528df078..28c0b74a8 100644 --- a/test/rulesets/Base/CoreLogging.jl +++ b/test/rulesets/Base/CoreLogging.jl @@ -8,4 +8,4 @@ check_inferred=false, ) end -end \ No newline at end of file +end From 99b814178e1d99b857477784348a703265843d52 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 3 Nov 2023 16:33:02 +0800 Subject: [PATCH 23/31] fix path --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 367a63de1..768f7c208 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,7 +53,7 @@ end test_method_tables() # Check the global method tables are consistent # Each file puts all tests inside one or more @testset blocks - include_test("rulesets/CoreLogging.jl") + include_test("rulesets/Base/CoreLogging.jl") include_test("rulesets/Base/base.jl") include_test("rulesets/Base/fastmath_able.jl") include_test("rulesets/Base/evalpoly.jl") From c92bf70ad0f2277154b28d0b353c5c382d621973 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 6 Nov 2023 12:47:06 +0800 Subject: [PATCH 24/31] Add version bounds for stdlibs --- Project.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Project.toml b/Project.toml index 1c7ebb77e..7bf2422e3 100644 --- a/Project.toml +++ b/Project.toml @@ -23,16 +23,20 @@ Adapt = "3.4.0" ChainRulesCore = "1.15.3" ChainRulesTestUtils = "1.5" Compat = "3.46, 4.2" +Distributed = "1" FiniteDifferences = "0.12.20" GPUArraysCore = "0.1.0" IrrationalConstants = "0.1.1, 0.2" JLArrays = "0.1" JuliaInterpreter = "0.8,0.9" +LinearAlgebra = "1" +Random = "1" RealDot = "0.1" SparseInverseSubset = "0.1" StaticArrays = "1.2" Statistics = "1" StructArrays = "0.6.11" +SuiteSparse = "1" julia = "1.6" [extras] From 01ea92b83d3f7b530303942eaa9e7df89028fa97 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 6 Nov 2023 21:42:42 +0800 Subject: [PATCH 25/31] bound sparse arrays --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 7bf2422e3..31d77496a 100644 --- a/Project.toml +++ b/Project.toml @@ -33,6 +33,7 @@ LinearAlgebra = "1" Random = "1" RealDot = "0.1" SparseInverseSubset = "0.1" +SparseArrays = "1" StaticArrays = "1.2" Statistics = "1" StructArrays = "0.6.11" From e05886a9a64d2f862332e24023a60c6a8a762615 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Sun, 9 Jul 2023 01:14:51 +0000 Subject: [PATCH 26/31] Don't use the array muladd rule for ZeroTangent --- src/rulesets/Base/arraymath.jl | 6 +++--- src/rulesets/Base/base.jl | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 7fbf46062..078bb602a 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -351,7 +351,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R function backslash_pullback(ȳ) Ȳ = unthunk(ȳ) - + Ȳf = Ȳ @static if VERSION >= v"1.9" # Need to ensure Ȳ is an array since since https://github.com/JuliaLang/julia/pull/44358 @@ -360,7 +360,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R end end Yf = Y - @static if VERSION >= v"1.9" + @static if VERSION >= v"1.9" # Need to ensure Yf is an array since since https://github.com/JuliaLang/julia/pull/44358 if !isa(Y, AbstractArray) Yf = [Y] @@ -371,7 +371,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R B̄ = A' \ Ȳf Ā = -B̄ * Y' t = (B - A * Y) * B̄' - @static if VERSION >= v"1.9" + @static if VERSION >= v"1.9" # Need to ensure t is an array since since https://github.com/JuliaLang/julia/pull/44358 if !isa(t, AbstractArray) t = [t] diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 9576abd98..28cc11d19 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -94,6 +94,7 @@ end @scalar_rule fma(x, y, z) (y, x, true) @scalar_rule muladd(x, y, z) (y, x, true) +@scalar_rule muladd(x::Union{Number, ZeroTangent}, y::Union{Number, ZeroTangent}, z::Union{Number, ZeroTangent}) (y, x, true) @scalar_rule rem2pi(x, r::RoundingMode) (true, NoTangent()) @scalar_rule( mod(x, y), From 47479f458ec798f824a4dc1806f196c39b2abd2c Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 10 Nov 2023 14:18:51 +0800 Subject: [PATCH 27/31] test muladd mixing numbers and zerotangents --- test/rulesets/Base/base.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 36452da1e..7d06b1421 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -153,6 +153,17 @@ test_rrule(muladd, 10randn(), randn(), randn()) end + @testset "muladd ZeroTangent" begin + test_frule(muladd, 2.0, 3.0, ZeroTangent()) + test_frule(muladd, 2.0, ZeroTangent(), 4.0) + test_frule(muladd, ZeroTangent(), 3.0, 4.0) + + test_rrule(muladd, 2.0, 3.0, ZeroTangent()) + test_rrule(muladd, 2.0, ZeroTangent(), 4.0) + test_rrule(muladd, ZeroTangent(), 3.0, 4.0) + end + + @testset "fma" begin test_frule(fma, 10randn(), randn(), randn()) test_rrule(fma, 10randn(), randn(), randn()) From 81c6e8ca011c680aaefdc1396e15bb10553e2116 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 10 Nov 2023 14:44:52 +0800 Subject: [PATCH 28/31] style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/rulesets/Base/base.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 7d06b1421..9a5278747 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -156,14 +156,13 @@ @testset "muladd ZeroTangent" begin test_frule(muladd, 2.0, 3.0, ZeroTangent()) test_frule(muladd, 2.0, ZeroTangent(), 4.0) - test_frule(muladd, ZeroTangent(), 3.0, 4.0) + test_frule(muladd, ZeroTangent(), 3.0, 4.0) test_rrule(muladd, 2.0, 3.0, ZeroTangent()) test_rrule(muladd, 2.0, ZeroTangent(), 4.0) - test_rrule(muladd, ZeroTangent(), 3.0, 4.0) + test_rrule(muladd, ZeroTangent(), 3.0, 4.0) end - @testset "fma" begin test_frule(fma, 10randn(), randn(), randn()) test_rrule(fma, 10randn(), randn(), randn()) From 87f49961e57b368e3f0afaa697f64a3f98f638c0 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 10 Nov 2023 16:04:58 +0800 Subject: [PATCH 29/31] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 31d77496a..82f4d6616 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.57.0" +version = "1.58.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 40b9058c1a6798a4f72bda1227786d584b44c822 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 1 Dec 2023 16:15:44 +1100 Subject: [PATCH 30/31] Allow unzip to return JLArray (#759) * Allow unzip to return JLArray * style * style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * extend comment Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- test/unzipped.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/test/unzipped.jl b/test/unzipped.jl index 97aaa23f5..4215f3a6e 100644 --- a/test/unzipped.jl +++ b/test/unzipped.jl @@ -87,11 +87,14 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map # TODO invent some tests of this rrule's pullback function @test unzip(jl([(1,2), (3,4), (5,6)])) == (jl([1, 3, 5]), jl([2, 4, 6])) - @test unzip(jl([(missing,2), (missing,4), (missing,6)]))[2] == jl([2, 4, 6]) - @test unzip(jl([(missing,2), (missing,4), (missing,6)]))[2] isa Base.ReinterpretArray - @test unzip(jl([(1,), (3,), (5,)]))[1] == jl([1, 3, 5]) - @test unzip(jl([(1,), (3,), (5,)]))[1] isa Base.ReinterpretArray + + # depending on Julia/package versions, may get ReinterpretArray or JLArray + # Either is acceptable + @test isa( + unzip(jl([(missing, 2), (missing, 4), (missing, 6)]))[2], + Union{Base.ReinterpretArray,JLArray}, + ) end -end \ No newline at end of file +end From ae37562ea1f16816a0d8fff24e0aca6cd594a40f Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Thu, 14 Dec 2023 00:58:46 -0500 Subject: [PATCH 31/31] Allow single indexing of arrays of GPU arrays (#760) * Allow single indexing of arrays of GPU arrays * bump version --- Project.toml | 2 +- src/rulesets/Base/indexing.jl | 2 +- test/rulesets/Base/indexing.jl | 8 ++++++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 82f4d6616..77f514046 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.58.0" +version = "1.58.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 2f5e6cf79..830571ecd 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -144,7 +144,7 @@ end ChainRules.@non_differentiable _setindex_zero(x::AbstractArray, dy::Any, inds::Any...) function ∇getindex!(dx::AbstractArray, dy, inds::Integer...) - view(dx, inds...) .+= Ref(dy) + @views dx[inds...] += dy return dx end function ∇getindex!(dx::AbstractArray, dy, inds...) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index a677df3b9..e878dd061 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -177,6 +177,14 @@ end @test Array(y3) == Array(x_23_gpu)[1, [1,1,2]] @test unthunk(bk3(jl(ones(3)))[2]) == jl([2 1 0; 0 0 0]) end + + @testset "getindex(::Array{<:AbstractGPUArray})" begin + x_gpu = jl(rand(1)) + y, back = rrule(getindex, [x_gpu], 1) + @test y === x_gpu + dxs_gpu = unthunk(back(jl([1.0]))[2]) + @test dxs_gpu == [jl([1.0])] + end end # first & tail handled by getfield rules