Zygote ad backend for normalizing flows #154

Closed
themrzmaster opened this issue Jan 4, 2021 · 7 comments


@themrzmaster

themrzmaster commented Jan 4, 2021

Hi,
The normalizing flow example uses Tracker, a discontinued AD package.
I am trying to fit a NF using Zygote, but I have some problems.

Example:

using Turing, Bijectors, Flux, ProgressMeter, DistributionsAD, Zygote

# Toy model used to generate 2-D training data
@model function gen()
    x ~ Exponential(0.5)
    y ~ Normal(0, x)
end
s = sample(gen(), NUTS(), 500)
# 2×N matrix with one column per draw of (x, y)
train_data = hcat(s[:x].data, s[:y].data)' |> Array

# Flow: planar and radial layers interleaved with LeakyReLU
b = PlanarLayer(2) ∘ LeakyReLU{Float64, 1}(0.05) ∘ RadialLayer(2) ∘ LeakyReLU{Float64, 1}(0.05) ∘ PlanarLayer(2)
d = MvNormal(zeros(2), ones(2))   # standard normal base distribution
tb = transformed(d, b);

# Negative log-likelihood of the data under the flow
loss(tb::Bijectors.TransformedDistribution, x::Matrix{Float64}) = sum(-logpdf(tb, x))

function nf_train(tb, x, opt, ps, epochs)
    @showprogress for i ∈ 1:epochs
        train_loss, back = Zygote.pullback(() -> loss(tb, x), ps)
        gs = back(one(train_loss))   # the error is raised here
        Flux.update!(opt, ps, gs)
    end
end
nf_train(tb, train_data, ADAM(), Zygote.Params(Flux.params(b)), 10)

I get the error:

Mutating arrays is not supported

on

gs = back(one(train_loss))

Any way I can make this work?

Thanks

@devmotion
Member

I can't run your example and it's a bit difficult to comment without knowing the exact error message. It is also a bit unclear to me which packages you used here - did you load DistributionsAD?

@themrzmaster
Author

@devmotion I've updated the example with the full code.
Yes, I've loaded DistributionsAD:
using Turing, Bijectors, Flux, ProgressMeter, DistributionsAD

@devmotion
Member

Thanks, now I can run the code. The error is caused by AD issues in the Roots package. The inverse of a planar layer is computed with a root-finding algorithm from the Roots package (see the line alpha = find_zero(initial_bracket) do x in the planar layer implementation, and the paper mentioned in the comments there), and it seems this is not compatible with Zygote.
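Why an inverse runs during training at all: logpdf of a transformed distribution evaluates the base density at b⁻¹(y) plus the log-abs-det-Jacobian of b⁻¹, so every loss evaluation goes through the planar layer inverse. The planar layer f(z) = z + u·tanh(wᵀz + b) has no closed-form inverse, but projecting onto w reduces it to a scalar equation in α = wᵀz. The following is a hypothetical, dependency-free sketch of that reduction: it uses plain bisection instead of Roots.find_zero, the names planar, solve_alpha and planar_inverse are made up for illustration, and it assumes wᵀu ≥ -1 rather than enforcing it.

```julia
# Hypothetical sketch of a planar layer and its inverse (not Bijectors' code).
# Bisection stands in for Roots.find_zero so the example has no dependencies.
using LinearAlgebra

# Forward map f(z) = z + u * tanh(w'z + b).
# Assumes w'u >= -1, which makes f invertible.
planar(z, w, u, b) = z .+ u .* tanh(dot(w, z) + b)

# Solve α + c * tanh(α + b) = wt_y for α by bisection.
# Since |tanh| <= 1, the root lies in [wt_y - |c|, wt_y + |c|],
# and the left-hand side is nondecreasing in α when c >= -1.
function solve_alpha(wt_y, c, b; tol = 1e-12, maxiter = 200)
    g(α) = α + c * tanh(α + b) - wt_y
    lo, hi = wt_y - abs(c), wt_y + abs(c)
    for _ in 1:maxiter
        mid = (lo + hi) / 2
        if g(mid) > 0
            hi = mid
        else
            lo = mid
        end
        hi - lo < tol && break
    end
    return (lo + hi) / 2
end

# Inverse: taking w' of y = z + u * tanh(w'z + b) gives the scalar
# equation w'y = α + (w'u) * tanh(α + b) with α = w'z.
# Once α is found, z = y - u * tanh(α + b).
function planar_inverse(y, w, u, b)
    α = solve_alpha(dot(w, y), dot(w, u), b)
    return y .- u .* tanh(α + b)
end
```

A round trip through planar and then planar_inverse should recover the input up to the solver tolerance. Differentiating this with an AD tool means tracing every bisection (or Roots) iteration, which is exactly the part Zygote choked on.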

@devmotion
Member

BTW, the different AD backends (Tracker, ForwardDiff, ReverseDiff, and Zygote) all have different advantages and disadvantages, and the optimal choice usually depends on the problem and your implementation. Tracker is not discontinued: like ForwardDiff, it is solid and maintained, but no major new features are planned.

@themrzmaster
Author

Thanks! I will open an issue on Zygote to see if there are any plans to support it, as it has some other advantages.

@themrzmaster
Author

themrzmaster commented Jan 10, 2021

Just to update: here is the error with the latest package versions.

Compiling Tuple{typeof(Bijectors.find_alpha),Float64,Float64,Float64}: try/catch is not supported.

Stacktrace:
[1] error(::String) at ./error.jl:33
[2] instrument(::IRTools.Inner.IR) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/reverse.jl:89
[3] #Primal#20 at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/reverse.jl:170 [inlined]
[4] Zygote.Adjoint(::IRTools.Inner.IR; varargs::Nothing, normalise::Bool) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/reverse.jl:283
[5] _lookup_grad(::Type{T} where T) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/emit.jl:101
[6] #s2937#1244 at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:37 [inlined]
[7] #s2937#1244(::Any, ::Any, ::Any) at ./none:0
[8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:527
[9] #1079 at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/lib/broadcast.jl:150 [inlined]
[10] #3844#back at /Users/luccazenobio/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[11] (::Zygote.var"#150#151"{Zygote.var"#3844#back#1082"{Zygote.var"#1079#1081"{typeof(∂(find_alpha))}},Tuple{NTuple{4,Nothing},Tuple{Nothing}}})(::Float64) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/lib/lib.jl:191
[12] #1693#back at /Users/luccazenobio/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[13] broadcasted at ./broadcast.jl:1263 [inlined]
[14] Inverse at /Users/luccazenobio/.julia/packages/Bijectors/0PJJc/src/bijectors/planar_layer.jl:117 [inlined]
[15] (::typeof(∂(λ)))(::Array{Float64,1}) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
[16] logabsdetjac at /Users/luccazenobio/.julia/packages/Bijectors/0PJJc/src/interface.jl:85 [inlined]
[17] (::typeof(∂(logabsdetjac)))(::Float64) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
[18] forward at /Users/luccazenobio/.julia/packages/Bijectors/0PJJc/src/interface.jl:98 [inlined]
[19] (::typeof(∂(forward)))(::NamedTuple{(:rv, :logabsdetjac),Tuple{Array{Float64,1},Float64}}) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
[20] macro expansion at /Users/luccazenobio/.julia/packages/Bijectors/0PJJc/src/bijectors/composed.jl:0 [inlined]
[21] forward at /Users/luccazenobio/.julia/packages/Bijectors/0PJJc/src/bijectors/composed.jl:219 [inlined]
[22] (::typeof(∂(forward)))(::NamedTuple{(:rv, :logabsdetjac),Tuple{Array{Float64,1},Float64}}) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
[23] _logpdf at /Users/luccazenobio/.julia/packages/Bijectors/0PJJc/src/transformed_distribution.jl:105 [inlined]
[24] (::typeof(∂(_logpdf)))(::Float64) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
[25] #323 at /Users/luccazenobio/.julia/packages/DistributionsAD/HvoZ3/src/zygote.jl:85 [inlined]
[26] (::typeof(∂(λ)))(::Float64) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
[27] #502 at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/lib/array.jl:187 [inlined]
[28] #3 at ./generator.jl:36 [inlined]
[29] iterate at ./generator.jl:47 [inlined]
[30] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Array{typeof(∂(λ)),1},Array{Float64,1}}},Base.var"#3#4"{Zygote.var"#502#506"}}) at ./array.jl:686
[31] map at ./abstractarray.jl:2248 [inlined]
[32] (::Zygote.var"#501#505"{Array{typeof(∂(λ)),1}})(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/lib/array.jl:187
[33] #2537#back at /Users/luccazenobio/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[34] #322 at /Users/luccazenobio/.julia/packages/DistributionsAD/HvoZ3/src/zygote.jl:85 [inlined]
[35] (::typeof(∂(#322)))(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
[36] #41 at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface.jl:40 [inlined]
[37] #532#back at /Users/luccazenobio/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[38] loss at ./In[13]:12 [inlined]
[39] #31 at ./In[13]:16 [inlined]
[40] (::typeof(∂(λ)))(::Float64) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
[41] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float64) at /Users/luccazenobio/.julia/packages/Zygote/ggM8Z/src/compiler/interface.jl:172
[42] macro expansion at ./In[13]:17 [inlined]
[43] macro expansion at /Users/luccazenobio/.julia/packages/ProgressMeter/GhAId/src/ProgressMeter.jl:762 [inlined]
[44] nf_train(::TransformedDistribution{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},Composed{Tuple{PlanarLayer{Array{Float64,1},Array{Float64,1}},LeakyReLU{Float64,1},RadialLayer{Array{Float64,1},Array{Float64,1}},LeakyReLU{Float64,1},PlanarLayer{Array{Float64,1},Array{Float64,1}}},1},Multivariate}, ::Array{Float64,2}, ::ADAM, ::Zygote.Params, ::Int64) at ./In[13]:15
[45] top-level scope at In[13]:21
[46] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1091

@devmotion
Member

This is fixed by #160. In general, it is a bad idea to differentiate through find_alpha directly; it is much more efficient to implement its derivatives explicitly, as done in this PR.
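The idea behind implementing the derivatives explicitly can be sketched with the implicit function theorem. The root α is defined by F(α) = α + c·tanh(α + b) - wt_y = 0, with c = wᵀu and wt_y = wᵀy, so at the solution dα/dθ = -(∂F/∂θ)/(∂F/∂α) for each input θ. This needs only the solved α, never the solver's iterations (or its try/catch). The block below is a minimal illustration under that assumption, not the code from #160; solve_alpha and alpha_partials are hypothetical names, and bisection again stands in for Roots.

```julia
# Hypothetical sketch (not the code from #160): explicit derivatives of the
# planar-layer root α, defined implicitly by
#     F(α) = α + c * tanh(α + b) - wt_y = 0,   with c = w'u, wt_y = w'y.
# Implicit function theorem: dα/dθ = -(∂F/∂θ) / (∂F/∂α).

# Plain bisection stand-in for a root finder; assumes c >= -1 so F is monotone.
function solve_alpha(wt_y, c, b; tol = 1e-14, maxiter = 200)
    F(α) = α + c * tanh(α + b) - wt_y
    lo, hi = wt_y - abs(c), wt_y + abs(c)   # valid bracket since |tanh| <= 1
    for _ in 1:maxiter
        mid = (lo + hi) / 2
        if F(mid) > 0
            hi = mid
        else
            lo = mid
        end
        hi - lo < tol && break
    end
    return (lo + hi) / 2
end

# Partials of α with respect to (wt_y, c, b), evaluated at a solved α.
# Only scalar arithmetic: the solver's iterations are never differentiated.
function alpha_partials(α, c, b)
    t = tanh(α + b)
    s = 1 - t^2                  # sech²(α + b), the derivative of tanh
    dF_dα = 1 + c * s            # positive when c > -1
    dα_dwty = 1 / dF_dα          # from ∂F/∂wt_y = -1
    dα_dc = -t / dF_dα           # from ∂F/∂c = tanh(α + b)
    dα_db = -c * s / dF_dα       # from ∂F/∂b = c * sech²(α + b)
    return (dα_dwty, dα_dc, dα_db)
end
```

In an AD framework these partials would be returned from a custom pullback (for example a ChainRulesCore `rrule` or a Zygote `@adjoint`), so the root finder itself is never traced and the backward pass costs a handful of scalar operations regardless of how many iterations the solve took.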
