Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix test errors and update to ChainRules 1 and Distributions 0.25.15 #198

Merged
merged 22 commits into from
Aug 31, 2021

Conversation

devmotion
Copy link
Member

@devmotion devmotion commented Aug 24, 2021

This PR builds on #197 and additionally updates the ChainRules definitions to ChainRules 1.

Unfortunately, locally Zygote tests seem to fail even with ChainRules 1. I opened a separate PR to not pollute #197 with unrelated changes.

Edit:

This PR

@devmotion
Copy link
Member Author

Ah... NNlib with CR 1 only supports Julia 1.6. IIRC it's only used in the definition of a helper function in the tests so maybe we should just remove it from the test dependencies.

@devmotion
Copy link
Member Author

@yebai I don't think we should move the ZygoteRules definitions in https://github.com/TuringLang/DistributionsAD.jl/blob/a603603e89fbafd4cfde6a69eb4c44874ca75cb7/src/zygote.jl to ChainRules. These are no useful differentiation rules but just hacks and workarounds for issues in Zygote (maybe some are not needed anymore?), so IMO we should not define them generally for oter AD systems even though we could call back into the AD system with CR1, similar to ZygoteRules.pullback. If we want to define them for Zygote exclusively with ChainRules.rrule, we could add a config::Zygote.ZygoteRulesConfig as first argument - but then we either have to depend on Zygote or load Zygote with Requires. In my opinion both options are worse than depending on the lightweight ZygoteRules package and so I don't think we should convert these rules to ChainRules.

@devmotion
Copy link
Member Author

Probably we need more time to debug and fix the Zygote issues, it's still not clear to me why it fails and in which Zygote/CR release it was broken and how it can be fixed. However, it seems DistributionsAD holds back other packages that would like to update to ChainRules 1 (e.g. SciML/SciMLSensitivity.jl#467). I assume that these Zygote issues are not new but existed for a while, in the tests in recent PRs they were only masked by the other failing tests I assume. Should we merge and release this PR once we are satisfied even if Zygote tests fail?

@devmotion
Copy link
Member Author

devmotion commented Aug 27, 2021

OK, so here's a MWE:

julia> using DistributionsAD, Zygote, Distributions

julia> Zygote.gradient(0.45) do p
           return sum(logpdf(arraydist([Bernoulli(p) for _ in 1:3]), zeros(3, 2)))
       end
(-10.909090909090908,)

julia> Zygote.gradient(0.45) do p
           return loglikelihood(arraydist([Bernoulli(p) for _ in 1:3]), zeros(3, 2))
       end
ERROR: Need an adjoint for constructor Bernoulli{Float64}. Gradient is of type ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{Bernoulli{Float64}, Nothing, false})(Δ::ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}})
    @ Zygote ~/.julia/packages/Zygote/YZfhu/src/lib/lib.jl:323
  [3] (::Zygote.var"#1812#back#229"{Zygote.Jnew{Bernoulli{Float64}, Nothing, false}})(Δ::ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ~/.julia/packages/Distributions/fXTVC/src/univariate/discrete/bernoulli.jl:30 [inlined]
  [5] (::typeof((Bernoulli{Float64})))(Δ::ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}})
    @ Zygote ~/.julia/packages/Zygote/YZfhu/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/Distributions/fXTVC/src/univariate/discrete/bernoulli.jl:35 [inlined]
  [7] (::typeof((#Bernoulli#41)))(Δ::ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}})
    @ Zygote ~/.julia/packages/Zygote/YZfhu/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/Distributions/fXTVC/src/univariate/discrete/bernoulli.jl:34 [inlined]
  [9] Pullback
    @ ./none:0 [inlined]
 [10] (::typeof((λ)))(Δ::ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}})
    @ Zygote ~/.julia/packages/Zygote/YZfhu/src/compiler/interface2.jl:0
 [11] #566
    @ ~/.julia/packages/Zygote/YZfhu/src/lib/array.jl:218 [inlined]
 [12] #4
    @ ./generator.jl:36 [inlined]
 [13] iterate
    @ ./generator.jl:47 [inlined]
 [14] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{Vector{Tuple{Bernoulli{Float64}, typeof((λ))}}, Vector{ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}}}}}, Base.var"#4#5"{Zygote.var"#566#572"}})
    @ Base ./array.jl:678
 [15] map
    @ ./abstractarray.jl:2383 [inlined]
 [16] (::Zygote.var"#563#569"{var"#34#36"{Float64}, Tuple{UnitRange{Int64}}, Vector{Tuple{Bernoulli{Float64}, typeof((λ))}}})(Δ::Vector{ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}}})
    @ Zygote ~/.julia/packages/Zygote/YZfhu/src/lib/array.jl:218
 [17] (::Zygote.var"#back#602"{Zygote.var"#563#569"{var"#34#36"{Float64}, Tuple{UnitRange{Int64}}, Vector{Tuple{Bernoulli{Float64}, typeof((λ))}}}})(ȳ::Vector{ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}}})
    @ Zygote ~/.julia/packages/Zygote/YZfhu/src/lib/array.jl:252
 [18] Pullback
    @ ./REPL[11]:2 [inlined]
 [19] (::typeof((#33)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/YZfhu/src/compiler/interface2.jl:0
 [20] (::Zygote.var"#50#51"{typeof((#33))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/YZfhu/src/compiler/interface.jl:41
 [21] gradient(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/YZfhu/src/compiler/interface.jl:76
 [22] top-level scope
    @ REPL[11]:1

Both approaches should return the same value. A single example works:

julia> Zygote.gradient(0.45) do p
           return logpdf(arraydist([Bernoulli(p) for _ in 1:3]), zeros(3))
       end
(-5.454545454545454,)

julia> Zygote.gradient(0.45) do p
           return loglikelihood(arraydist([Bernoulli(p) for _ in 1:3]), zeros(3))
       end
(-5.454545454545454,)

And filldist works with multiple samples and both logpdf and loglikelihood:

julia> Zygote.gradient(0.45) do p
           return sum(logpdf(filldist(Bernoulli(p), 3), zeros(3, 2)))
       end
(-10.909090909090908,)

julia> Zygote.gradient(0.45) do p
           return loglikelihood(filldist(Bernoulli(p), 3), zeros(3, 2))
       end
(-10.909090909090908,)

@devmotion
Copy link
Member Author

OK, even simpler:

julia> Zygote.gradient(0.45) do p
           d = arraydist([Bernoulli(p) for _ in 1:3])
           x = zeros(3, 2)
           return sum(Base.Fix1(Distributions._logpdf, d), eachcol(x))
       end
(-10.909090909090908,)

julia> Zygote.gradient(0.45) do p
           d = arraydist([Bernoulli(p) for _ in 1:3])
           x = zeros(3, 2)
           return sum(xi -> Distributions._logpdf(d, xi), eachcol(x))
       end
(-10.909090909090908,)

julia> Zygote.gradient(0.45) do p
           d = arraydist([Bernoulli(p) for _ in 1:3])
           x = zeros(3, 2)
           return sum(i -> Distributions._logpdf(d, view(x, :, i)), axes(x, 2))
       end
ERROR: Need an adjoint for constructor Bernoulli{Float64}. Gradient is of type ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
...

julia> Zygote.gradient(0.45) do p
           d = arraydist([Bernoulli(p) for _ in 1:3])
           x = zeros(3, 2)
           return sum(i -> Distributions._logpdf(d, view(x, :, i)), 1:size(x, 2))
       end
ERROR: Need an adjoint for constructor Bernoulli{Float64}. Gradient is of type ChainRulesCore.Tangent{Bernoulli{Float64}, NamedTuple{(:p,), Tuple{Float64}}}
Stacktrace:
  [1] error(s::String)
...

The last example is exactly how loglikelihood for multiple samples is defined in Distributions: https://github.com/JuliaStats/Distributions.jl/blob/59df675409a7e2490e4a45edd32c0267df435c55/src/multivariates.jl#L267

@oxinabox @DhairyaLGandhi Any idea why the first examples work but the last ones fail?

@devmotion
Copy link
Member Author

I added a workaround for the issue above, hopefully a similar workaround can fix the matrixvariate issues.

The StatsFuns CR definitions are removed since they were moved to StatsFuns (available in >= 0.9.10). I opened a PR to Distributions that transfers the CR definition for the pdf of PoissonBinomial: JuliaStats/Distributions.jl#1390 If it is merged, we should remove the definitions in DistributionsAD.

end
function ChainRulesCore.rrule(::typeof(to_simplex), x::AbstractArray{<:Real})
y = to_simplex(x)
pullback(ȳ) = (NoTangent(), to_simplex_pullback(ȳ, y))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May want to think about a ProbectTo here, but that is beyond scope of this PR I guess

@devmotion
Copy link
Member Author

Zygote tests pass again 🎉

@devmotion devmotion changed the title Extension of #197 with ChainRules 1 Fix test errors and update to ChainRules 1 and Distributions 0.25.15 Aug 30, 2021
# Loglikelihood of multi- and matrixvariate distributions: multiple samples
# workaround for Zygote issues discussed in
# https://github.com/TuringLang/DistributionsAD.jl/pull/198
ZygoteRules.@adjoint function Distributions.loglikelihood(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for looking into this - it looks good to me!

@@ -73,8 +73,8 @@
DistSpec(Beta, (1.0, 2.0), 0.5),

DistSpec(BetaPrime, (), 0.5),
DistSpec(BetaPrime, (1.0,), 0.5),
DistSpec(BetaPrime, (1.0, 2.0), 0.5),
DistSpec(BetaPrime, (1.5,), 0.5),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also noticed the same issue but thought it only affects ReverseDiff with tape caching. When tape caching is disabled, ReverseDiff should handle control flows correctly (similar to Tracker).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's ReverseDiff specific since e.g. Tracker failed as well (https://github.com/TuringLang/DistributionsAD.jl/pull/185/checks?check_run_id=3338494448#step:5:1097). In fact all AD systems apart from Zygote return a value that is different from finite differencing:

julia> using Distributions, FiniteDifferences, Tracker, ReverseDiff, ForwardDiff, Zygote

julia> f(x) = logpdf(BetaPrime(x), 0.5)
f (generic function with 1 method)

julia> central_fdm(5, 1)(f, 1.0)
0.495922603223565

julia> central_fdm(12, 1)(f, 1.0)
0.4959226032237455

julia> Tracker.gradient(f, 1.0)
(1.1890697837836721 (tracked),)

julia> ForwardDiff.derivative(f, 1.0)
1.1890697837836721

julia> ReverseDiff.gradient(f  first, [1.0])
1-element Vector{Float64}:
 1.1890697837836721

julia> Zygote.gradient(f, 1.0)
(0.49592260322372683,)

The different values arise from the different behaviour of the AD systems for x -> xlogy(x, 0.5) at 0. FiniteDifferences and Zygote return the correct derivative (log(0.5)) whereas Tracker, ForwardDiff, and ReverseDiff return 0:

julia> using LogExpFunctions

julia> central_fdm(12, 1)(x -> xlogy(x, 0.5), 0.0)
-0.6931471805599452

julia> Tracker.gradient(x -> xlogy(x, 0.5), 0.0)
(0.0 (tracked),)

julia> ForwardDiff.derivative(x -> xlogy(x, 0.5), 0.0)
0.0

julia> ReverseDiff.gradient(x -> xlogy(first(x), 0.5), [0.0])
1-element Vector{Float64}:
 0.0

julia> Zygote.gradient(x -> xlogy(x, 0.5), 0.0)
(-0.6931471805599453,)

It seems the AD systems handle the branch in https://github.com/JuliaStats/LogExpFunctions.jl/blob/39223ba1daa0244fae1885cd0bcba5a25743349b/src/basicfuns.jl#L29 differently/incorrectly.

@yebai
Copy link
Member

yebai commented Aug 31, 2021

@yebai I don't think we should move the ZygoteRules definitions in https://github.com/TuringLang/DistributionsAD.jl/blob/a603603e89fbafd4cfde6a69eb4c44874ca75cb7/src/zygote.jl to ChainRules. These are no useful differentiation rules but just hacks and workarounds for issues in Zygote

It sounds good to keep them here, or depreciate them if possible, or move them into another package (e.g. DynamicPPL) where appropriate.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
3 participants