-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix test errors and update to ChainRules 1 and Distributions 0.25.15 #198
Conversation
Ah... NNlib with CR 1 only supports Julia 1.6. IIRC it's only used in the definition of a helper function in the tests so maybe we should just remove it from the test dependencies. |
@yebai I don't think we should move the ZygoteRules definitions in https://github.com/TuringLang/DistributionsAD.jl/blob/a603603e89fbafd4cfde6a69eb4c44874ca75cb7/src/zygote.jl to ChainRules. These are no useful differentiation rules but just hacks and workarounds for issues in Zygote (maybe some are not needed anymore?), so IMO we should not define them generally for oter AD systems even though we could call back into the AD system with CR1, similar to |
Probably we need more time to debug and fix the Zygote issues, it's still not clear to me why it fails and in which Zygote/CR release it was broken and how it can be fixed. However, it seems DistributionsAD holds back other packages that would like to update to ChainRules 1 (e.g. SciML/SciMLSensitivity.jl#467). I assume that these Zygote issues are not new but existed for a while, in the tests in recent PRs they were only masked by the other failing tests I assume. Should we merge and release this PR once we are satisfied even if Zygote tests fail? |
OK, so here's a MWE: julia> using DistributionsAD, Zygote, Distributions
julia> Zygote.gradient(0.45) do p
return sum(logpdf(arraydist([Bernoulli(p) for _ in 1:3]), zeros(3, 2)))
end
(-10.909090909090908,)
julia> Zygote.gradient(0.45) do p
return loglikelihood(arraydist([Bernoulli(p) for _ in 1:3]), zeros(3, 2))
end
ERROR: Need an adjoint for constructor Bernoulli{Float64}. Gradient is of type ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.Jnew{Bernoulli{Float64}, Nothing, false})(Δ::ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}})
@ Zygote ~/.julia/packages/Zygote/YZfhu/src/lib/lib.jl:323
[3] (::Zygote.var"#1812#back#229"{Zygote.Jnew{Bernoulli{Float64}, Nothing, false}})(Δ::ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}})
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[4] Pullback
@ ~/.julia/packages/Distributions/fXTVC/src/univariate/discrete/bernoulli.jl:30 [inlined]
[5] (::typeof(∂(Bernoulli{Float64})))(Δ::ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}})
@ Zygote ~/.julia/packages/Zygote/YZfhu/src/compiler/interface2.jl:0
[6] Pullback
@ ~/.julia/packages/Distributions/fXTVC/src/univariate/discrete/bernoulli.jl:35 [inlined]
[7] (::typeof(∂(#Bernoulli#41)))(Δ::ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}})
@ Zygote ~/.julia/packages/Zygote/YZfhu/src/compiler/interface2.jl:0
[8] Pullback
@ ~/.julia/packages/Distributions/fXTVC/src/univariate/discrete/bernoulli.jl:34 [inlined]
[9] Pullback
@ ./none:0 [inlined]
[10] (::typeof(∂(λ)))(Δ::ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}})
@ Zygote ~/.julia/packages/Zygote/YZfhu/src/compiler/interface2.jl:0
[11] #566
@ ~/.julia/packages/Zygote/YZfhu/src/lib/array.jl:218 [inlined]
[12] #4
@ ./generator.jl:36 [inlined]
[13] iterate
@ ./generator.jl:47 [inlined]
[14] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{Vector{Tuple{Bernoulli{Float64}, typeof(∂(λ))}}, Vector{ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}}}}}, Base.var"#4#5"{Zygote.var"#566#572"}})
@ Base ./array.jl:678
[15] map
@ ./abstractarray.jl:2383 [inlined]
[16] (::Zygote.var"#563#569"{var"#34#36"{Float64}, Tuple{UnitRange{Int64}}, Vector{Tuple{Bernoulli{Float64}, typeof(∂(λ))}}})(Δ::Vector{ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}}})
@ Zygote ~/.julia/packages/Zygote/YZfhu/src/lib/array.jl:218
[17] (::Zygote.var"#back#602"{Zygote.var"#563#569"{var"#34#36"{Float64}, Tuple{UnitRange{Int64}}, Vector{Tuple{Bernoulli{Float64}, typeof(∂(λ))}}}})(ȳ::Vector{ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}}})
@ Zygote ~/.julia/packages/Zygote/YZfhu/src/lib/array.jl:252
[18] Pullback
@ ./REPL[11]:2 [inlined]
[19] (::typeof(∂(#33)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/YZfhu/src/compiler/interface2.jl:0
[20] (::Zygote.var"#50#51"{typeof(∂(#33))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/YZfhu/src/compiler/interface.jl:41
[21] gradient(f::Function, args::Float64)
@ Zygote ~/.julia/packages/Zygote/YZfhu/src/compiler/interface.jl:76
[22] top-level scope
@ REPL[11]:1 Both approaches should return the same value. A single example works: julia> Zygote.gradient(0.45) do p
return logpdf(arraydist([Bernoulli(p) for _ in 1:3]), zeros(3))
end
(-5.454545454545454,)
julia> Zygote.gradient(0.45) do p
return loglikelihood(arraydist([Bernoulli(p) for _ in 1:3]), zeros(3))
end
(-5.454545454545454,) And julia> Zygote.gradient(0.45) do p
return sum(logpdf(filldist(Bernoulli(p), 3), zeros(3, 2)))
end
(-10.909090909090908,)
julia> Zygote.gradient(0.45) do p
return loglikelihood(filldist(Bernoulli(p), 3), zeros(3, 2))
end
(-10.909090909090908,) |
OK, even simpler: julia> Zygote.gradient(0.45) do p
d = arraydist([Bernoulli(p) for _ in 1:3])
x = zeros(3, 2)
return sum(Base.Fix1(Distributions._logpdf, d), eachcol(x))
end
(-10.909090909090908,)
julia> Zygote.gradient(0.45) do p
d = arraydist([Bernoulli(p) for _ in 1:3])
x = zeros(3, 2)
return sum(xi -> Distributions._logpdf(d, xi), eachcol(x))
end
(-10.909090909090908,)
julia> Zygote.gradient(0.45) do p
d = arraydist([Bernoulli(p) for _ in 1:3])
x = zeros(3, 2)
return sum(i -> Distributions._logpdf(d, view(x, :, i)), axes(x, 2))
end
ERROR: Need an adjoint for constructor Bernoulli{Float64}. Gradient is of type ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
...
julia> Zygote.gradient(0.45) do p
d = arraydist([Bernoulli(p) for _ in 1:3])
x = zeros(3, 2)
return sum(i -> Distributions._logpdf(d, view(x, :, i)), 1:size(x, 2))
end
ERROR: Need an adjoint for constructor Bernoulli{Float64}. Gradient is of type ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}}
Stacktrace:
[1] error(s::String)
... The last example is exactly how @oxinabox @DhairyaLGandhi Any idea why the first examples work but the last ones fail? |
I added a workaround for the issue above, hopefully a similar workaround can fix the matrixvariate issues. The StatsFuns CR definitions are removed since they were moved to StatsFuns (available in >= 0.9.10). I opened a PR to Distributions that transfers the CR definition for the pdf of |
end | ||
function ChainRulesCore.rrule(::typeof(to_simplex), x::AbstractArray{<:Real}) | ||
y = to_simplex(x) | ||
pullback(ȳ) = (NoTangent(), to_simplex_pullback(ȳ, y)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
May want to think about a ProbectTo
here, but that is beyond scope of this PR I guess
Zygote tests pass again 🎉 |
# Loglikelihood of multi- and matrixvariate distributions: multiple samples | ||
# workaround for Zygote issues discussed in | ||
# https://github.com/TuringLang/DistributionsAD.jl/pull/198 | ||
ZygoteRules.@adjoint function Distributions.loglikelihood( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for looking into this - it looks good to me!
@@ -73,8 +73,8 @@ | |||
DistSpec(Beta, (1.0, 2.0), 0.5), | |||
|
|||
DistSpec(BetaPrime, (), 0.5), | |||
DistSpec(BetaPrime, (1.0,), 0.5), | |||
DistSpec(BetaPrime, (1.0, 2.0), 0.5), | |||
DistSpec(BetaPrime, (1.5,), 0.5), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also noticed the same issue but thought it only affects ReverseDiff with tape caching. When tape caching is disabled, ReverseDiff should handle control flows correctly (similar to Tracker).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it's ReverseDiff specific since e.g. Tracker failed as well (https://github.com/TuringLang/DistributionsAD.jl/pull/185/checks?check_run_id=3338494448#step:5:1097). In fact all AD systems apart from Zygote return a value that is different from finite differencing:
julia> using Distributions, FiniteDifferences, Tracker, ReverseDiff, ForwardDiff, Zygote
julia> f(x) = logpdf(BetaPrime(x), 0.5)
f (generic function with 1 method)
julia> central_fdm(5, 1)(f, 1.0)
0.495922603223565
julia> central_fdm(12, 1)(f, 1.0)
0.4959226032237455
julia> Tracker.gradient(f, 1.0)
(1.1890697837836721 (tracked),)
julia> ForwardDiff.derivative(f, 1.0)
1.1890697837836721
julia> ReverseDiff.gradient(f ∘ first, [1.0])
1-element Vector{Float64}:
1.1890697837836721
julia> Zygote.gradient(f, 1.0)
(0.49592260322372683,)
The different values arise from the different behaviour of the AD systems for x -> xlogy(x, 0.5)
at 0
. FiniteDifferences and Zygote return the correct derivative (log(0.5)
) whereas Tracker, ForwardDiff, and ReverseDiff return 0:
julia> using LogExpFunctions
julia> central_fdm(12, 1)(x -> xlogy(x, 0.5), 0.0)
-0.6931471805599452
julia> Tracker.gradient(x -> xlogy(x, 0.5), 0.0)
(0.0 (tracked),)
julia> ForwardDiff.derivative(x -> xlogy(x, 0.5), 0.0)
0.0
julia> ReverseDiff.gradient(x -> xlogy(first(x), 0.5), [0.0])
1-element Vector{Float64}:
0.0
julia> Zygote.gradient(x -> xlogy(x, 0.5), 0.0)
(-0.6931471805599453,)
It seems the AD systems handle the branch in https://github.com/JuliaStats/LogExpFunctions.jl/blob/39223ba1daa0244fae1885cd0bcba5a25743349b/src/basicfuns.jl#L29 differently/incorrectly.
It sounds good to keep them here, or depreciate them if possible, or move them into another package (e.g. |
This PR builds on #197 and additionally updates the ChainRules definitions to ChainRules 1.
Unfortunately, locally Zygote tests seem to fail even with ChainRules 1. I opened a separate PR to not pollute #197 with unrelated changes.
Edit:
This PR
PoissonBinomial
which was moved to Distributions (Add ChainRules definitions for pdf ofPoissonBinomial
JuliaStats/Distributions.jl#1390); Distributions dependency is bumped to 0.25.15 which contains this ruleBetaPrime
distribution has broken tests with the recent release of Distributions. #196MvNormal
constructors