Compatibility with Flux #59
If I'm not mistaken, the definitions in https://github.com/PumasAI/SimpleChains.jl/blob/main/src/chain_rules.jl should be more than enough for this purpose. For example, your code snippet already works if I change the import to `using Flux: mse # only required for the loss function`. If you're making your own layers in a library, just write:
```julia
using Functors # for interop with Flux's module system
using Optimisers, Zygote # training loop essentials sans Flux
using SimpleChains
using Flux: mse # only required for the loss function

# Example drop-in layer that will work wherever a Flux model is expected.
# Zygote and Optimisers will let you train wrt. the params vector directly,
# but bundling state and behaviour opens up the possibility of working with
# higher-level libraries like FastAI.jl.
struct WrappedSimpleChain{M,P}
    model::M
    params::P
end

@functor WrappedSimpleChain

(m::WrappedSimpleChain)(x) = m.model(x, m.params)

loss(m, x, y) = mse(m(x), y)

let
    sc = SimpleChain(8, TurboDense(identity, 8))
    p = SimpleChains.init_params(sc)
    model = WrappedSimpleChain(sc, p)
    opt_state = Optimisers.setup(Optimisers.ADAM(), model)
    x = randn(Float32, 8, 1)
    y = 0.75x
    for i in 1:10
        grads, = gradient(model) do m
            loss(m, x, y)
        end
        opt_state, model = Optimisers.update!(opt_state, model, grads)
        @info i, loss(model, x, y)
    end
end
```

So not only does the compatibility appear to be there, but it looks pretty future-proof as well!
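To illustrate the "drop-in" claim, here is a sketch (unverified, with arbitrary layer sizes; the composition is my own example, not from this thread) of the same wrapper sandwiched inside an ordinary `Flux.Chain`:

```julia
using Flux, SimpleChains

struct WrappedSimpleChain{M,P}
    model::M
    params::P
end

Flux.@functor WrappedSimpleChain

(m::WrappedSimpleChain)(x) = m.model(x, m.params)

sc = SimpleChain(8, TurboDense(tanh, 8))
inner = WrappedSimpleChain(sc, SimpleChains.init_params(sc))

# SimpleChains layer sandwiched between two ordinary Flux layers;
# Flux sees `inner` as just another callable layer with parameters.
flux_model = Chain(Dense(4, 8), inner, Dense(8, 2))
flux_model(randn(Float32, 4, 1))
```

Because `@functor` registers the wrapper's fields, Flux/Optimisers should traverse into `params` the same way they traverse any other layer's weights.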
Note that simple calls/gradient calculations through this wrapper are (a) not multithreaded and (b) require manual batching.
I think it would be really useful to allow interoperability with things like Flux, Zygote, etc.

I think much of the benefit of this package would be applicable to other projects if a `SimpleChain` could be used without attaching a loss or optimizer or anything else, but simply used for its allocation-free forward and backward pass, called directly. While we would still deal with the overhead of Zygote, Flux's optimizers, etc., we would at least be able to eliminate our model's allocation burden, which may be helpful for many users.
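For concreteness, a minimal sketch of the kind of direct usage meant here. This assumes the two-argument call `sc(x, p)` plus the `rrule` definitions in chain_rules.jl are enough for Zygote to differentiate through; the layer sizes and loss are arbitrary illustrations:

```julia
using SimpleChains, Zygote

sc = SimpleChain(8, TurboDense(tanh, 16), TurboDense(identity, 4))
p = SimpleChains.init_params(sc)
x = randn(Float32, 8, 32) # a batch of 32 inputs

ŷ = sc(x, p)                    # forward pass, no loss or optimizer attached
myloss(p) = sum(abs2, sc(x, p)) # loss defined entirely outside SimpleChains
g, = Zygote.gradient(myloss, p) # backward pass via the chain_rules.jl definitions
```

The point is that the loss and the update rule both live in user code; `SimpleChain` contributes only the (allocation-free) forward and backward computation.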