
Compatibility with Flux #59

Open
staticfloat opened this issue Apr 23, 2022 · 2 comments

Comments


staticfloat commented Apr 23, 2022

I think it would be really useful to allow interoperability with things like Flux, Zygote, etc...

I think much of the benefit of this package would carry over to other projects if a SimpleChain could be used without attaching a loss, an optimizer, or anything else, and instead be used simply for its allocation-free forward and backward passes. So you should be able to do something like:

using SimpleChains, Flux

model = SimpleChain(8, TurboDense(SimpleChains.tanh, 8))
p = SimpleChains.init_params(model)
Flux.train!(
    # 'loss' function, returns the value to be minimized
    (x, y) -> Flux.Losses.mse(model(x, p), y),
    # Parameters to be optimized over
    p,
    # dataset
    [(randn(8,1), randn(8,1))],
    # optimizer
    Flux.Optimise.ADAM()
)

While we would still deal with the overhead of Zygote, Flux's optimizers, etc., we would at least be able to eliminate our model's allocation burden, which may be helpful for many users.

staticfloat changed the title from "Standardize ML API" to "Compatibility with Flux" on Apr 23, 2022
@ToucheSir

If I'm not mistaken, the definitions in https://github.com/PumasAI/SimpleChains.jl/blob/main/src/chain_rules.jl should be more than enough for this purpose. For example, your code snippet already works if I change p -> Flux.params([p]). More complex model configurations are also possible:

using Flux: mse # only required for the loss function. If you're making your own layers in a library, just write:
using Functors # for interop with Flux's module system
using Optimisers, Zygote # training loop essentials sans Flux
using SimpleChains

# Example drop-in layer that will work wherever a Flux model is expected.
# Zygote and Optimisers will let you train wrt. the params vector directly,
# but bundling state and behaviour opens up the possibility of working with higher-level libraries like FastAI.jl.
struct WrappedSimpleChain{M,P}
  model::M
  params::P
end
@functor WrappedSimpleChain

(m::WrappedSimpleChain)(x) = m.model(x, m.params)

loss(m, x, y) = mse(m(x), y)

let
  sc = SimpleChain(8, TurboDense(identity, 8))
  p = SimpleChains.init_params(sc)
  model = WrappedSimpleChain(sc, p)
  opt_state = Optimisers.setup(Optimisers.ADAM(), model)

  x = randn(Float32, 8, 1)
  y = 0.75x
  for i in 1:10
    grads, = gradient(model) do m
      loss(m, x, y)
    end
    opt_state, model = Optimisers.update!(opt_state, model, grads)
    @info i, loss(model, x, y)
  end
end

So not only does the compatibility appear to be there, but it looks pretty future-proof as well!
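
For reference, here is a minimal sketch of the p -> Flux.params([p]) change mentioned above, assuming the implicit-parameter Flux.train! API Flux provided at the time; the single-sample dataset and Float32 shapes are just illustrative:

using SimpleChains, Flux

model = SimpleChain(8, TurboDense(SimpleChains.tanh, 8))
p = SimpleChains.init_params(model)

Flux.train!(
    # loss closes over `model` and `p`; gradients reach `p` via the ChainRules definitions
    (x, y) -> Flux.Losses.mse(model(x, p), y),
    # wrap the raw parameter vector so Zygote's implicit-params machinery tracks it
    Flux.params([p]),
    # toy dataset: a single (input, target) pair
    [(randn(Float32, 8, 1), randn(Float32, 8, 1))],
    Flux.Optimise.ADAM(),
)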

@chriselrod
Contributor

Note that SimpleChains' memory management is not thread-safe, and thus requires manual handling.

The train_*! methods manage memory manually so they can multithread safely.
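
For illustration, a rough sketch of that built-in route, based on the SimpleChains README rather than this thread (the exact train_unbatched! call and the SquaredLoss target shape are assumptions here):

using SimpleChains

model = SimpleChain(8, TurboDense(SimpleChains.tanh, 8))
X = randn(Float32, 8, 64)                             # inputs, one sample per column
Y = randn(Float32, 8, 64)                             # matching targets
mloss = SimpleChains.add_loss(model, SquaredLoss(Y))  # attach a loss for training
p = SimpleChains.init_params(model)
G = SimpleChains.alloc_threaded_grad(mloss)           # per-thread gradient buffers

# train_unbatched! manages its scratch memory itself, so it can use threads safely
SimpleChains.train_unbatched!(G, p, mloss, X, SimpleChains.ADAM(), 1_000)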

Plain calls/gradient calculations, however, are (a) not multithreaded, and (b) manually batching and calling them from Threads.@threads, Threads.@spawn, or Polyester.@batch will give corrupted/wrong results.
You can work around this by having one SimpleChain per task.
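
A minimal sketch of that workaround, assuming deepcopy is an acceptable way to give each task its own chain (this is one reading of "one SimpleChain per task", not code from the thread):

using SimpleChains
using Base.Threads

base = SimpleChain(8, TurboDense(SimpleChains.tanh, 8))
p = SimpleChains.init_params(base)
batches = [randn(Float32, 8, 16) for _ in 1:4]

tasks = map(batches) do x
    Threads.@spawn begin
        chain = deepcopy(base)  # task-local SimpleChain, so no scratch memory is shared
        chain(x, p)             # forward pass uses only this task's buffers
    end
end
outputs = fetch.(tasks)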
