Zygote AD backend for normalizing flows #154
I can't run your example, and it's difficult to comment without knowing the exact error message. It's also unclear to me which packages you used here: did you load DistributionsAD?
@devmotion I've updated the example with the full code.
Thanks, now I can run the code. The error is caused by problems with differentiating through the Roots package: the inverse of a planar layer is computed with a root-finding algorithm from Roots (see `Bijectors.jl/src/bijectors/planar_layer.jl`, line 122 in commit a854144).
BTW, the different AD backends (Tracker, ForwardDiff, ReverseDiff, and Zygote) all have different advantages and disadvantages, and the best choice usually depends on the problem and your implementation. Tracker is not discontinued: like ForwardDiff, it is solid and maintained, but no major new features are planned.
Thanks! I will open an issue on Zygote to see if there are any plans to support this, since Zygote has some other advantages.
Just an update with the latest package versions: the error is now `Compiling Tuple{typeof(Bijectors.find_alpha),Float64,Float64,Float64}: try/catch is not supported.`
This is fixed by #160. In general, it is a bad idea to just differentiate through …
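For context on why differentiating through a root finder is fragile: Zygote has to trace every iteration of the solver, including any `try`/`catch` inside it. A common alternative is to give the solver a custom adjoint based on the implicit function theorem, so Zygote never compiles the solver body at all. The sketch below is a hypothetical scalar example: `solve_root` (not Bijectors' `find_alpha`) solves `x + a*tanh(x) = y` by Newton's method, and it is not necessarily what #160 does. Only `Zygote.@adjoint` and `Zygote.gradient` are Zygote's own API.

```julia
using Zygote

# Hypothetical scalar solver (not Bijectors.find_alpha): finds x such that
# F(x, y, a) = x + a * tanh(x) - y == 0 using Newton's method.
# For 0 <= a < 1, F is strictly increasing in x and the iteration converges.
function solve_root(y::Real, a::Real)
    x = float(y)                                        # initial guess
    for _ in 1:100
        t = tanh(x)
        step = (x + a * t - y) / (1 + a * (1 - t^2))    # Newton step F / (∂F/∂x)
        x -= step
        abs(step) < 1e-14 && break
    end
    return x
end

# Custom adjoint: Zygote uses this rule instead of tracing the loop above,
# so nothing inside the solver body needs to be differentiable by Zygote.
Zygote.@adjoint function solve_root(y::Real, a::Real)
    x = solve_root(y, a)            # plain primal call, not traced
    t = tanh(x)
    dFdx = 1 + a * (1 - t^2)        # ∂F/∂x evaluated at the root
    # Implicit function theorem:
    #   dx/dy = -(∂F/∂y) / (∂F/∂x) =  1 / dFdx
    #   dx/da = -(∂F/∂a) / (∂F/∂x) = -t / dFdx
    solve_root_pullback(Δ) = (Δ / dFdx, -Δ * t / dFdx)
    return x, solve_root_pullback
end

# Gradients of the root with respect to y and a.
g_y, g_a = Zygote.gradient(solve_root, 1.0, 0.5)
```

The design choice here is that the derivative depends only on the root itself, not on how it was found, so the solver is free to use whatever control flow it needs (including error handling that Zygote cannot compile).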
Hi,
The normalizing flow example uses Tracker, a discontinued AD package.
I am trying to fit an NF using Zygote, but I am running into problems.
Example:
I get the error `Mutating arrays is not supported` on the line
`gs = back(one(train_loss))`
Is there any way to make this work?
Thanks
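The example code from the original post is not reproduced here, so the exact source of the mutation is unknown. As a hedged sketch of the general issue: Zygote cannot differentiate code that writes into arrays in place (for example `out[i] = ...`), and the usual fixes are to rewrite the computation without mutation or to use `Zygote.Buffer`. The loss functions, toy data, and parameter names below are hypothetical; only `Zygote.pullback`, `Zygote.gradient`, and `Zygote.Buffer` are Zygote's own API. If the mutation happens inside a library rather than in the user's loss (the later comments point at the Roots-based planar-layer inverse), rewriting the loss will not help.

```julia
using Zygote

# Toy data and parameters (hypothetical; the original example is not shown).
xs = [0.0, 1.0, 2.0]
θ = [2.0, 1.0]

# Writing into a preallocated array inside the loss is array mutation,
# which Zygote rejects with "Mutating arrays is not supported".
function loss_mutating(θ, xs)
    out = zeros(length(xs))
    for i in eachindex(xs)
        out[i] = (θ[1] * xs[i] + θ[2])^2
    end
    return sum(out)
end

# Option 1: express the same computation without mutation (broadcasting).
loss_pure(θ, xs) = sum((θ[1] .* xs .+ θ[2]) .^ 2)

# Option 2: keep the loop, but write into a Zygote.Buffer and copy it
# into an ordinary array before using the result.
function loss_buffer(θ, xs)
    out = Zygote.Buffer(xs)
    for i in eachindex(xs)
        out[i] = (θ[1] * xs[i] + θ[2])^2
    end
    return sum(copy(out))
end

# Same pullback pattern as in the question: evaluate, then pull back 1.
train_loss, back = Zygote.pullback(θ -> loss_pure(θ, xs), θ)
gs = back(one(train_loss))   # tuple; gs[1] ≈ [26.0, 18.0] for this toy data
```

Broadcasting is usually the simpler fix; `Zygote.Buffer` is useful when a loop-based formulation is hard to avoid.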