Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testing new apply interface for Flux.Chain #5

Merged
merged 8 commits into from
Mar 26, 2023

Conversation

mkschleg
Copy link
Contributor

New Feature:

Implemented new apply functionality which returns a newly constructed chain as well as the output of applying the input to the current chain. This is in relation to the recent conversation for updating the recurrent network interface. The apply was described by @ToucheSir.

Most of the code is a modification of the _applychain functionality from Flux.jl. The general direction is to likely interface with Accessors.jl. Some minor tests were added, but we will want to include many of the layer types as we expand the feature.

PR Checklist

  • Tests are added
  • Documentation, if applicable

src/chain.jl Outdated Show resolved Hide resolved
src/chain.jl Outdated Show resolved Hide resolved
src/chain.jl Outdated Show resolved Hide resolved
src/chain.jl Outdated Show resolved Hide resolved
src/chain.jl Outdated Show resolved Hide resolved
test/chain.jl Outdated Show resolved Hide resolved
mkschleg and others added 3 commits March 22, 2023 10:13
Batch suggestions from review.

Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
@mkschleg
Copy link
Contributor Author

I added the pullback for the vector chain. It likely doesn't work on GPUs yet, and may not be implemented in the most efficient way possible. But it passes the simple test comparing to the other interface.

@mkschleg
Copy link
Contributor Author

mkschleg commented Mar 23, 2023

I also decided to unify _apply with _apply_to_layer because I initially forgot about chains which can be composed of other chains through the parallel and other interfaces.

Edit: This doesn't solve the issue necessarily. I'm not quite sure how to maintain backwards compat for Chains within Chains here. Maybe a separate dispatch for _apply for Flux.Chains?

@mkschleg
Copy link
Contributor Author

mkschleg commented Mar 24, 2023

Also, implementing the non-mutating recur for this interface requires us to make a decision about how to handle time-series inputs for a chain. This is because [chain(x) for x in xs] will no longer work (for hopefully clear reasons). Instead, we will need a custom map function (which maybe _apply can dispatch on?). In equinox, and other jax libraries this is handled through jax.vmap, I believe. I could see this kind of function going two ways:

  1. We handle this at the model level, and run the forward pass for the entire model on each time-step of the input.
  2. We handle this at the layer level, and run the forward pass for the entire input for each layer sequentially.

I could see 2 offering more optimization options. But I'm not sure how well it will interact with the rest of Flux (especially future planned stuff).

Copy link
Member

@ToucheSir ToucheSir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do have the 3D array API for Recur which tackles some of this, but it's definitely lacking and I'd very much like to find a nicer solution. JAX RNNs generally make use of lax.scan for sequence handling, so that immediately comes to mind as inspiration.

We can worry about all that after merging this PR, though :)

@ToucheSir ToucheSir merged commit 19a9205 into FluxML:master Mar 26, 2023
@mkschleg
Copy link
Contributor Author

@ToucheSir For sure. And yes it is, but the 3d api was never meant to be fully featured from what I remember when we were working on it. I'll put up another PR for the non-mutating recur when I make some more progress.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants