using Zygote #669
In my own experiments trying to use

```julia
zyg_update!(opt, model, updates::Nothing) = nothing

function zyg_update!(opt, model::AbstractArray, updates::AbstractArray)
    # Sub off to Flux's ADAM optimizer
    Δ = Flux.Optimise.update!(opt, model, updates)
    return model .-= Δ
end

function zyg_update!(opt, model, updates)
    if nfields(model) == 0
        return model
    end
    for field_idx in 1:nfields(model)
        zyg_update!(opt, getfield(model, field_idx), getfield(updates, field_idx))
    end
end
```

Things actually work fairly well, except BatchNorm freaks out, complaining about mutating arrays. To work around this, I am using my own
I wanted to use this with the new NNlib overhaul, so I rebased this branch on top of
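For readers following the thread, the recursive update idea from the earlier comment can be sketched without Flux at all. This is only an illustration under stated assumptions: `ToyDense`, `ToyChain`, `sgd_update!`, and the step size `η` are hypothetical names, and a plain SGD step is swapped in for the ADAM call used in the comment. It relies on the gradient mirroring the model's structure, with a NamedTuple per struct (as Zygote returns) and `nothing` wherever a field has no gradient.

```julia
# Hypothetical toy layers for illustration; these are not Flux types.
struct ToyDense
    W::Matrix{Float64}
    b::Vector{Float64}
end

struct ToyChain{T<:Tuple}
    layers::T
end

# No gradient for this leaf or subtree: leave it untouched.
sgd_update!(η, model, grads::Nothing) = nothing

# Array leaf: apply a plain SGD step in place (assumption: SGD, not ADAM).
function sgd_update!(η, model::AbstractArray, grads::AbstractArray)
    model .-= η .* grads
    return model
end

# Anything else: recurse over fields, pairing model and gradient by position.
# A value with zero fields falls through the empty loop unchanged.
function sgd_update!(η, model, grads)
    for i in 1:nfields(model)
        sgd_update!(η, getfield(model, i), getfield(grads, i))
    end
    return model
end

# Usage: a hand-written gradient shaped like the model.
model = ToyChain((ToyDense([1.0 2.0; 3.0 4.0], [0.5, 0.5]),))
grads = (layers = ((W = ones(2, 2), b = nothing),),)
sgd_update!(0.5, model, grads)
```

Dispatching on `Nothing` first keeps the walk from descending into parts of the model that received no gradient, which is the same role the `updates::Nothing` method plays in the comment's version.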
docs/src/training/optimisers.md

````diff
@@ -3,25 +3,25 @@
 Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.
 
 ```julia
-using Flux, Flux.Tracker
+using Flux, Flux.Zygote
````
@dhairyagandhi96 Flux already exports `gradient`, so this may not be necessary
Fixed
Co-Authored-By: Mike J Innes <mike.j.innes@gmail.com>
bors r+
669: using Zygote r=MikeInnes a=MikeInnes

Otherwise known as "break all the things". This will be a huge change so I'm beginning to prepare now, even though Zygote is still a couple of months off from being really ready. **Do not try this at home** (yet) – this branch is eventually aimed at beta testers, but isn't even ready for that yet.

The idea is to break as little code as possible, which means supporting the current `Params` API; but I also want to start prototyping the nicer things discussed in #628 and other issues.

Blocking issues:

* [x] Get the tests passing.
* [x] Check tests on GPU.
* [x] Rewrite all the docs.
* [x] Cache invalidation (JuliaLabs/Cassette.jl#6).
* [x] Moving over adjoints (FluxML/Zygote.jl#81).
* [x] General Zygote robustness.

Nice to have:

* [ ] Robust nested AD (may not be a blocker if one can still use Tracker with Flux).
* [x] Zygote support for modules / globals as discussed in #628, along with #637.
* [x] Better train/test mode as in #643.

If you're the kind of person who ignores triangular road signs, you can try this with

```julia
]add Flux#zygote Zygote#master
```

Co-authored-by: Mike J Innes <mike.j.innes@gmail.com>
Co-authored-by: Elliot Saba <staticfloat@gmail.com>
Co-authored-by: thebhatman <manjunathbhat9920@gmail.com>
Seem to be some issues with our GPU CI, so just merging.

Build failed

```julia
rnn.state = Tracker.data(rnn.state)
"""
truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)
```
well, is there an alternative?
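Otherwise known as "break all the things". This will be a huge change so I'm beginning to prepare now, even though Zygote is still a couple of months off from being really ready. **Do not try this at home** (yet) – this branch is eventually aimed at beta testers, but isn't even ready for that yet.

The idea is to break as little code as possible, which means supporting the current `Params` API; but I also want to start prototyping the nicer things discussed in #628 and other issues.

Blocking issues:

* [x] Get the tests passing.
* [x] Check tests on GPU.
* [x] Rewrite all the docs.
* [x] Cache invalidation (JuliaLabs/Cassette.jl#6).
* [x] Moving over adjoints (FluxML/Zygote.jl#81).
* [x] General Zygote robustness.

Nice to have:

* [ ] Robust nested AD (may not be a blocker if one can still use Tracker with Flux).
* [x] Zygote support for modules / globals as discussed in #628, along with #637.
* [x] Better train/test mode as in #643.

If you're the kind of person who ignores triangular road signs, you can try this with

```julia
]add Flux#zygote Zygote#master
```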