Skip to content

Fix errors revealed by Zygote's tests #175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Apr 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.5.0"
version = "0.5.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -11,7 +11,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.7"
ChainRulesTestUtils = "0.2.1"
ChainRulesTestUtils = "0.2.2"
Compat = "3"
FiniteDifferences = "0.9"
Reexport = "0.2"
Expand Down
22 changes: 18 additions & 4 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

@scalar_rule(abs(x::Real), sign(x))
@scalar_rule(abs2(x), 2x)
@scalar_rule(exp(x), Ω)
@scalar_rule(exp(x::Real), Ω)
@scalar_rule(exp10(x), Ω * log(oftype(x, 10)))
@scalar_rule(exp2(x), Ω * log(oftype(x, 2)))
@scalar_rule(expm1(x), exp(x))
Expand Down Expand Up @@ -45,7 +45,8 @@
@scalar_rule(sinh(x), cosh(x))
@scalar_rule(tanh(x), 1-Ω^2)

@scalar_rule(acosh(x), inv(sqrt(x^2 - 1)))
# Can't multiply though sqrt in acosh because of negative complex case for x
@scalar_rule(acosh(x), inv(sqrt(x - 1) * sqrt(x + 1)))
@scalar_rule(acoth(x), inv(1 - x^2))
@scalar_rule(acsch(x), -inv(x^2 * sqrt(1 + x^-2)))
@scalar_rule(acsch(x::Real), -inv(abs(x) * sqrt(1 + x^2)))
Expand All @@ -66,7 +67,9 @@
@scalar_rule(-(x, y), (One(), -1))
@scalar_rule(/(x, y), (inv(y), -(x / y / y)))
@scalar_rule(\(x, y), (-(y / x / x), inv(x)))
@scalar_rule(^(x, y), (ifelse(iszero(y), zero(Ω), y * x^(y - 1)), Ω * log(x)))

#log(complex(x)) is require so it give correct complex answer for x<0
@scalar_rule(^(x, y), (ifelse(iszero(y), zero(Ω), y * x^(y - 1)), Ω * log(complex(x))))

@scalar_rule(cbrt(x), inv(3 * Ω^2))
@scalar_rule(inv(x), -Ω^2)
Expand Down Expand Up @@ -117,7 +120,7 @@ end

function rrule(::typeof(*), x::Number, y::Number)
function times_pullback(ΔΩ)
return (NO_FIELDS, @thunk(ΔΩ * y'), @thunk(x' * ΔΩ))
return (NO_FIELDS, @thunk(ΔΩ * y), @thunk(x * ΔΩ))
end
return x * y, times_pullback
end
Expand All @@ -132,3 +135,14 @@ function rrule(::typeof(identity), x)
end
return x, identity_pullback
end

function rrule(::typeof(identity), x::Tuple)
# `identity(::Tuple)` returns multiple outputs;because that is how we think of
# returning a tuple, so its pullback needs to accept multiple inputs.
# `identity(::Tuple)` has one input, so its pullback should return 1 matching output
# see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152
function identity_pullback(ȳs...)
return (NO_FIELDS, Composite{typeof(x)}(ȳs...))
end
return x, identity_pullback
end
23 changes: 11 additions & 12 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,17 @@ end
##### `det`
#####

function frule((_, ẋ), ::typeof(det), x)
function frule((_, ẋ), ::typeof(det), x::Union{Number, AbstractMatrix})
Ω = det(x)
# TODO Performance optimization: probably there is an efficent
# way to compute this trace without during the full compution within
return Ω, Ω * tr(inv(x) * ẋ)
end

function rrule(::typeof(det), x)
function rrule(::typeof(det), x::Union{Number, AbstractMatrix})
Ω = det(x)
function det_pullback(ΔΩ)
return NO_FIELDS, @thunk(Ω * ΔΩ * inv(x)')
return NO_FIELDS, Ω * ΔΩ * transpose(inv(x))
end
return Ω, det_pullback
end
Expand All @@ -59,15 +59,15 @@ end
##### `logdet`
#####

function frule((_, Δx), ::typeof(logdet), x)
function frule((_, Δx), ::typeof(logdet), x::Union{Number, AbstractMatrix})
Ω = logdet(x)
return Ω, tr(inv(x) * Δx)
end

function rrule(::typeof(logdet), x)
function rrule(::typeof(logdet), x::Union{Number, AbstractMatrix})
Ω = logdet(x)
function logdet_pullback(ΔΩ)
return (NO_FIELDS, @thunk(ΔΩ * inv(x)'))
return (NO_FIELDS, ΔΩ * transpose(inv(x)))
end
return Ω, logdet_pullback
end
Expand All @@ -81,6 +81,8 @@ function frule((_, Δx), ::typeof(tr), x)
end

function rrule(::typeof(tr), x)
# This should really be a FillArray
# see https://github.com/JuliaDiff/ChainRules.jl/issues/46
function tr_pullback(ΔΩ)
return (NO_FIELDS, @thunk Diagonal(fill(ΔΩ, size(x, 1))))
end
Expand Down Expand Up @@ -121,14 +123,11 @@ function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
C, dC_pb = rrule(adjoint, Cᵀ)
function slash_pullback(Ȳ)
# Optimization note: dAᵀ, dBᵀ, dC are calculated no matter which partial you want
# this is not a problem if you want the 2nd or 3rd, but if you want the first, it
# is fairly wasteful
_, dC = dC_pb(Ȳ)
_, dBᵀ, dAᵀ = dS_pb(extern(dC))
_, dBᵀ, dAᵀ = dS_pb(unthunk(dC))

# need to extern as dAᵀ, dBᵀ are generally `Thunk`s, which don't support adjoint
∂A = @thunk last(dA_pb(extern(dAᵀ)))
∂B = @thunk last(dA_pb(extern(dBᵀ)))
∂A = last(dA_pb(unthunk(dAᵀ)))
∂B = last(dA_pb(unthunk(dBᵀ)))

(NO_FIELDS, ∂A, ∂B)
end
Expand Down
11 changes: 9 additions & 2 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@
#####

function rrule(::Type{<:Diagonal}, d::AbstractVector)
function Diagonal_pullback(ȳ)
return (NO_FIELDS, @thunk(diag(ȳ)))
function Diagonal_pullback(ȳ::AbstractMatrix)
return (NO_FIELDS, diag(ȳ))
end
function Diagonal_pullback(ȳ::Composite)
# TODO: Assert about the primal type in the Composite, It should be Diagonal
# infact it should be exactly the type of `Diagonal(d)`
# but right now Zygote loses primal type information so we can't use it.
# See https://github.com/FluxML/Zygote.jl/issues/603
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i am confused by this: how do we end up with a Composite but not a Composite{P} for the correct P?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because in the Zygote PR we turn all NamedTuples into Composite{Any}

return (NO_FIELDS, ȳ.diag)
end
return Diagonal(d), Diagonal_pullback
end
Expand Down
39 changes: 36 additions & 3 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,15 @@
test_scalar(acsc, 1/x)
test_scalar(acot, 1/x)
end
@testset "Inverse hyperbolic" for x = (0.5, Complex(0.5, 0.25))
@testset "Inverse hyperbolic" for x = (0.5, Complex(0.5, 0.25), Complex(-2.1 -3.1im))
test_scalar(asinh, x)
test_scalar(acosh, x + 1) # +1 accounts for domain
test_scalar(acosh, x + 1) # +1 accounts for domain for real
test_scalar(atanh, x)
test_scalar(asech, x)
test_scalar(acsch, x)
test_scalar(acoth, x + 1)
end

@testset "Inverse degrees" for x = (0.5, Complex(0.5, 0.25))
test_scalar(asind, x)
test_scalar(acosd, x)
Expand Down Expand Up @@ -100,7 +101,23 @@
end
end

@testset "*(x, y)" begin
@testset "*(x, y) (scalar)" begin
# This is pretty important so testing it fairly heavily
test_points = (0.0, -2.1, 3.2, 3.7+2.12im, 14.2-7.1im)
@testset "$x * $y; (perturbed by: $perturb)" for
x in test_points, y in test_points, perturb in test_points

# give small off-set so as can't slip in symmetry
x̄ = ẋ = 0.5 + perturb
ȳ = ẏ = 0.6 + perturb
Δz = perturb

frule_test(*, (x, ẋ), (y, ẏ))
rrule_test(*, Δz, (x, x̄), (y, ȳ))
end
end

@testset "matmul *(x, y)" begin
x, y = rand(3, 2), rand(2, 5)
z, pullback = rrule(*, x, y)

Expand All @@ -125,10 +142,26 @@
rrule_test(f, Δz, (x, x̄), (y, ȳ))
end

@testset "x^n for x<0" begin
rng = MersenneTwister(123456)
x = -15*rand(rng)
Δx, x̄ = 10rand(rng, 2)
y, Δy, ȳ = rand(rng, 3)
Δz = rand(rng)

frule_test(^, (-x, Δx), (y, Δy))
rrule_test(^, Δz, (-x, x̄), (y, ȳ))
end

@testset "identity" begin
rng = MersenneTwister(1)
rrule_test(identity, randn(rng), (randn(rng), randn(rng)))
rrule_test(identity, randn(rng, 4), (randn(rng, 4), randn(rng, 4)))

rrule_test(
identity, Tuple(randn(rng, 3)),
(Composite{Tuple}(randn(rng, 3)...), Composite{Tuple}(randn(rng, 3)...))
)
end

@testset "Constants" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0+200im)
Expand Down
8 changes: 8 additions & 0 deletions test/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@
rrule_test(Diagonal, D, (randn(rng, N), randn(rng, N)))
# Concrete type instead of UnionAll
rrule_test(typeof(D), D, (randn(rng, N), randn(rng, N)))

# TODO: replace this with a `rrule_test` once we have that working
# see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/24
res, pb = rrule(Diagonal, [1, 4])
@test pb(10*res) == (NO_FIELDS, [10, 40])
comp = Composite{typeof(res)}(; diag=10*res.diag) # this is the structure of Diagonal
@test pb(comp) == (NO_FIELDS, [10, 40])
end

@testset "::Diagonal * ::AbstractVector" begin
rng, N = MersenneTwister(123456), 3
rrule_test(
Expand Down