Skip to content

Commit f3ebf9e

Browse files
authored
Merge pull request #1039 from mzgubic/mz/cr1
Extra fixes for ChainRulesCore @1.0
2 parents 06aaae2 + 9153689 commit f3ebf9e

File tree

6 files changed

+22
-7
lines changed

6 files changed

+22
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2424
[compat]
2525
AbstractFFTs = "0.5, 1.0"
2626
ChainRules = "0.8.12"
27-
ChainRulesCore = "0.10.4"
27+
ChainRulesCore = "1"
2828
ChainRulesTestUtils = "0.7.1"
2929
DiffRules = "1.0"
3030
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"

src/compiler/chainrules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ for T_outer in (:Tuple, :NamedTuple)
111111
# than happy.
112112
@eval @inline function wrap_chainrules_output(x::ChainRules.Tangent{P, T}) where {P, T<:$T_outer}
113113
xp = map(wrap_chainrules_output, canonicalize(x))
114-
convert($T_outer, xp)
114+
ChainRulesCore.backing(xp) # this is accessing ChainRulesCore internals, but it is prob safe enough, and it is fastest
115115
end
116116
end
117117

test/complex.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ using Zygote, Test, LinearAlgebra
1818
@test gradient(x -> real(logabsdet(x)[1]), [1 2im; 3im 4])[1] ≈ [4 3im; 2im 1]/10
1919

2020
# https://github.com/FluxML/Zygote.jl/issues/705
21-
@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ im .* exp.(1:3)
21+
@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ real(im .* exp.(1:3))
2222
@test gradient(x -> imag(sum(exp, x)), [1+0im,2,3])[1] ≈ im .* exp.(1:3)
2323

2424
fs_C_to_R = (real,

test/features.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,12 +449,12 @@ end
449449
@test pullback(type_test)[1] == Complex{<:Real}
450450

451451
@testset "Pairs" begin
452-
@test (x->10*pairs((a=x, b=2))[1])'(100) === 10
452+
@test (x->10*pairs((a=x, b=2))[1])'(100) === 10.0
453453
@test (x->10*pairs((a=x, b=2))[2])'(100) === 0
454454
foo(;kw...) = 1
455455
@test gradient(() -> foo(a=1,b=2.0)) === ()
456456

457-
@test (x->10*(x => 2)[1])'(100) === 10
457+
@test (x->10*(x => 2)[1])'(100) === 10.0
458458
@test (x->10*(x => 2)[2])'(100) === 0
459459
end
460460

test/gradcheck.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ end
8181
@test gradient(xs ->sum(xs .^ _pow), [4, -1]) == ([_pow*4^9, -10],)
8282

8383
@test gradient(x -> real((1+3im) * x^2), 5+7im) == (-32 - 44im,)
84-
@test gradient(p -> real((1+3im) * (5+7im)^p), 2)[1] ≈ (-234 + 2im)*log(5 - 7im)
84+
@test gradient(p -> real((1+3im) * (5+7im)^p), 2)[1] ≈ real((-234 + 2im)*log(5 - 7im))
8585
# D[(1+3I)x^p, p] /. {x->5+7I, p->2} // Conjugate
8686
end
8787

@@ -160,7 +160,7 @@ end
160160

161161
# https://github.com/FluxML/Zygote.jl/issues/376
162162
_, back = Zygote._pullback(x->x[1]*im, randn(2))
163-
@test back(1.0)[2] == [-im, 0]
163+
@test back(1.0)[2] == real([-im, 0]) == [0, 0]
164164

165165
# _droplike
166166
@test gradient(x -> sum(inv, x[1, :]'), ones(2, 2)) == ([-1 -1; 0 0],)

test/utils.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,25 @@ end
2424

2525
xs, y = randn(2,3), rand()
2626
f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments
27+
28+
# The following should work once we work out what ForwardDiff should do when `Float64` is called on a `Dual`
29+
# https://github.com/JuliaDiff/ForwardDiff.jl/pull/538
30+
# else might need a custom overload of `(::ChainRulesCore.ProjectTo)(::Dual)`
31+
# When fixed uncomment the below and delete the broken function
32+
#==
2733
dx, dy = diaghessian(f34, xs, y)
2834
@test size(dx) == size(xs)
2935
@test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs))
3036
@test dy ≈ hessian(y -> f34(xs,y), y)
37+
==#
38+
function broken()
39+
dx, dy = diaghessian(f34, xs, y) # This fails because ProjectTo can't project a Dual onto a Float
40+
c1 = size(dx) == size(xs)
41+
c2 = vec(dx) ≈ diag(hessian(x -> f34(x,y), xs))
42+
c3 = dy ≈ hessian(y -> f34(xs,y), y)
43+
return all([c1, c2, c3])
44+
end
45+
@test_broken broken()
3146

3247
zs = randn(7,13) # test chunk mode
3348
@test length(zs) > ForwardDiff.DEFAULT_CHUNK_THRESHOLD

0 commit comments

Comments
 (0)