From eee717ae56b424007a6f1587d21a9b5c89a7a92f Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Sun, 6 Dec 2020 13:22:33 -0800
Subject: [PATCH 1/8] Increment required ChainRules version

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 41c407f7d..9cf587ca1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -26,7 +26,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 [compat]
 AbstractFFTs = "0.5"
 ArrayLayouts = "0.1, 0.2, 0.3, 0.4"
-ChainRules = "0.7.33"
+ChainRules = "0.7.34"
 DiffRules = "1.0"
 FillArrays = "0.8, 0.9, 0.10"
 ForwardDiff = "0.10"

From 358d4c3909730677eca43a473a156a3d406fe35d Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Sun, 6 Dec 2020 13:22:42 -0800
Subject: [PATCH 2/8] Remove norm adjoint

---
 src/lib/array.jl | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/src/lib/array.jl b/src/lib/array.jl
index e422a2365..b27929ef7 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -421,11 +421,6 @@ end
   end
 end
 
-function _pullback(cx::AContext, ::typeof(norm), x::AbstractArray, p::Real = 2)
-  fallback = (x, p) -> sum(abs.(x).^p .+ eps(0f0)) ^ (one(eltype(x)) / p) # avoid d(sqrt(x))/dx == Inf at 0
-  _pullback(cx, fallback, x, p)
-end
-
 # LinAlg Matrix Types
 # ===================
 

From a70baf81bededecb5a0db8912af9007f7c8114f0 Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Sun, 6 Dec 2020 13:22:51 -0800
Subject: [PATCH 3/8] Stop testing norm

---
 test/gradcheck.jl | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index ff32b31e0..b5ce105c2 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1105,8 +1105,8 @@ end
   Y = copy(X)
   Δ = randn(P, P)
   Δ_fd = FiniteDifferences.j′vp(
-    FiniteDifferences.central_fdm(5, 1), 
-    X -> pairwise(metric, X, Y; dims=2), 
+    FiniteDifferences.central_fdm(5, 1),
+    X -> pairwise(metric, X, Y; dims=2),
     Δ, X)
   _, pb = Zygote.pullback(X -> pairwise(metric, X, Y; dims=2), X)
 
@@ -1642,5 +1642,3 @@ end
 @test gradient(x -> sum(randexp(Random.default_rng(), Float32, (1,1))), 1) == (nothing,)
 end
 end
-
-@test gradient(x -> norm(x), rand(Float32, 2, 2))[1] isa Matrix{Float32}

From f96a30f81a0f2e37853fea66e9d89986f87111ba Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Sun, 6 Dec 2020 13:22:59 -0800
Subject: [PATCH 4/8] Increment version number

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 9cf587ca1..e3fc60d35 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "Zygote"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.5.15"
+version = "0.5.16"
 
 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

From 8ac0abe5602b498d371fdeaf104376938387581c Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Sun, 6 Dec 2020 13:49:03 -0800
Subject: [PATCH 5/8] Revert "Stop testing norm"

This reverts commit a70baf81bededecb5a0db8912af9007f7c8114f0.
---
 test/gradcheck.jl | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index b5ce105c2..ff32b31e0 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1105,8 +1105,8 @@ end
   Y = copy(X)
   Δ = randn(P, P)
   Δ_fd = FiniteDifferences.j′vp(
-    FiniteDifferences.central_fdm(5, 1),
-    X -> pairwise(metric, X, Y; dims=2),
+    FiniteDifferences.central_fdm(5, 1), 
+    X -> pairwise(metric, X, Y; dims=2), 
     Δ, X)
   _, pb = Zygote.pullback(X -> pairwise(metric, X, Y; dims=2), X)
 
@@ -1642,3 +1642,5 @@ end
 @test gradient(x -> sum(randexp(Random.default_rng(), Float32, (1,1))), 1) == (nothing,)
 end
 end
+
+@test gradient(x -> norm(x), rand(Float32, 2, 2))[1] isa Matrix{Float32}

From c2808cca3d23cac0a27fa7d29fd015c9e38923ad Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Sun, 6 Dec 2020 13:49:23 -0800
Subject: [PATCH 6/8] Stop testing norm adjoint

---
 test/gradcheck.jl | 2 --
 1 file changed, 2 deletions(-)

diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index ff32b31e0..7eca8c08b 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1642,5 +1642,3 @@ end
 @test gradient(x -> sum(randexp(Random.default_rng(), Float32, (1,1))), 1) == (nothing,)
 end
 end
-
-@test gradient(x -> norm(x), rand(Float32, 2, 2))[1] isa Matrix{Float32}

From 3bbe71db75895cc299edda74b21e8598a9fa9a84 Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Sun, 6 Dec 2020 19:04:27 -0800
Subject: [PATCH 7/8] Revert "Stop testing norm adjoint"

This reverts commit c2808cca3d23cac0a27fa7d29fd015c9e38923ad.
---
 test/gradcheck.jl | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 7eca8c08b..ff32b31e0 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1642,3 +1642,5 @@ end
 @test gradient(x -> sum(randexp(Random.default_rng(), Float32, (1,1))), 1) == (nothing,)
 end
 end
+
+@test gradient(x -> norm(x), rand(Float32, 2, 2))[1] isa Matrix{Float32}

From b9b77fee66c597636c61e1898ff92e83c455e3f6 Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Sun, 6 Dec 2020 19:43:36 -0800
Subject: [PATCH 8/8] Add tests for norm-related issues

---
 test/gradcheck.jl | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index ff32b31e0..9917558f0 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1643,4 +1643,21 @@ end
 end
 end
 
-@test gradient(x -> norm(x), rand(Float32, 2, 2))[1] isa Matrix{Float32}
+@testset "norm" begin
+  # rrule for norm is defined in ChainRules. These tests just check various norm-related
+  # issues are resolved
+
+  # check that type is not unnecessarily promoted
+  # https://github.com/FluxML/Zygote.jl/issues/663
+  @test gradient(norm, randn(Float32, 2, 2)) isa Tuple{Matrix{Float32}}
+  @test gradient(norm, randn(Float32, 2, 2), 3) isa Tuple{Matrix{Float32},Float32}
+  @test gradient(norm, randn(Float32, 2, 2), 3f0) isa Tuple{Matrix{Float32},Float32}
+  @test gradient(norm, randn(ComplexF32, 2, 2), 3.5f0) isa Tuple{Matrix{ComplexF32},Float32}
+
+  # just check that these do not error
+  # https://github.com/FluxML/Zygote.jl/issues/331
+  gradient(x->norm(x*[1, 1]), 1.23)
+  gradient(x->norm(x*[1 1]), 1.23)
+  gradient(x->norm(x*[1im, 1]), 1.23)
+  gradient(x->norm(x*[1im 1]), 1.23)
+end
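
A minimal usage sketch of the behavior the new "norm" testset asserts, assuming
Zygote 0.5.16 with ChainRules 0.7.34 as pinned in patches 1 and 4. With the
hand-written Zygote adjoint removed in patch 2, the gradient of norm comes from
the ChainRules rrule and preserves the input element type:

    using LinearAlgebra, Zygote

    x = randn(Float32, 2, 2)
    g, = gradient(norm, x)         # g isa Matrix{Float32}; no promotion to Float64
    dx, dp = gradient(norm, x, 3)  # p-norm: dx isa Matrix{Float32}, dp isa Float32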