Testing new apply interface for Flux.Chain #5
Batch suggestions from review. Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
I added the pullback for the vector chain. It likely doesn't work on GPUs yet and may not be implemented in the most efficient way, but it passes the simple test comparing it to the other interface.
I also decided to unify _apply with _apply_to_layer, because I initially forgot about chains that can be composed of other chains through the parallel and other interfaces. Edit: This doesn't necessarily solve the issue. I'm not quite sure how to maintain backwards compatibility for Chains within Chains here. Maybe a separate dispatch for
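For readers following along, here is a minimal sketch of what an apply that returns both a rebuilt chain and the output could look like when chains are nested. Everything in it is a toy stand-in: Scale, Accum, ToyChain and _apply_layers are hypothetical names, not Flux types or the code in this PR. It only illustrates that treating a chain as just another layer lets nested chains fall out of ordinary dispatch.

```julia
# Toy stand-ins, NOT Flux types: a stateless layer, a stateful layer, and a chain.
struct Scale            # stateless: y = w * x
    w::Float64
end

struct Accum            # stateful: keeps a running sum of its inputs
    state::Float64
end

struct ToyChain{T<:Tuple}
    layers::T
end
ToyChain(layers...) = ToyChain(layers)

# apply returns (new_layer, output); nothing is mutated.
apply(l::Scale, x) = (l, l.w * x)
function apply(l::Accum, x)
    s = l.state + x
    return Accum(s), s
end

# A chain is handled like any other layer, so a chain inside a chain
# is reached by the same dispatch.
function apply(c::ToyChain, x)
    new_layers, y = _apply_layers(c.layers, x)
    return ToyChain(new_layers), y
end

# Recurse over the tuple so the rebuilt layers stay a tuple.
_apply_layers(::Tuple{}, x) = (), x
function _apply_layers(layers::Tuple, x)
    l, y = apply(first(layers), x)
    rest, out = _apply_layers(Base.tail(layers), y)
    return (l, rest...), out
end

c = ToyChain(Scale(2.0), ToyChain(Accum(0.0)), Scale(0.5))
c1, y1 = apply(c, 3.0)   # c is left untouched; c1 carries the updated Accum
c2, y2 = apply(c1, 1.0)
```

The original chain c keeps its initial state after both calls, which is the property a non-mutating interface needs.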
Also, implementing the non-mutating recur for this interface requires us to make a decision about how to handle time-series inputs for a chain. This is because
I could see 2 offering more optimization options, but I'm not sure how well it will interact with the rest of Flux (especially future planned stuff).
We do have the 3D array API for Recur, which tackles some of this, but it's definitely lacking and I'd very much like to find a nicer solution. JAX RNNs generally make use of lax.scan for sequence handling, so that immediately comes to mind as inspiration.
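To make the lax.scan comparison concrete, here is a minimal sketch of a scan-style loop in plain Julia. The names scan and cumulative_step are hypothetical and not part of Flux, and the features × batch × time layout of the 3D array is an assumption made for this example. The step function takes a carry and one time slice and returns the new carry plus an output, so no state is mutated along the way.

```julia
# Hypothetical helper, NOT a Flux function. It follows the lax.scan contract:
# f(carry, x) returns (new_carry, y); scan returns (final_carry, all_ys).
function scan(f, init, xs)
    carry = init
    ys = []
    for x in xs
        carry, y = f(carry, x)
        push!(ys, y)
    end
    return carry, map(identity, ys)  # map(identity, ...) narrows the element type
end

# Toy step: the carry is the running sum of the time slices seen so far,
# and the per-step output is the total of that running sum.
function cumulative_step(h, xt)
    h_new = h .+ xt
    return h_new, sum(h_new)
end

# Assumed layout: features × batch × time, so each slice along dims=3 is one time step.
x = reshape(collect(1.0:8.0), 2, 2, 2)
h_final, ys = scan(cumulative_step, zeros(2, 2), eachslice(x; dims=3))
```

A recurrent cell would slot in where cumulative_step is, with the hidden state as the carry.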
We can worry about all that after merging this PR, though :)
@ToucheSir For sure. And yes it is, but the 3D API was never meant to be fully featured, from what I remember when we were working on it. I'll put up another PR for the non-mutating recur when I make some more progress.
New Feature:
Implemented new apply functionality, which returns a newly constructed chain as well as the output of applying the current chain to the input. This relates to the recent conversation about updating the recurrent network interface. The apply was described by @ToucheSir.
Most of the code is a modification of the _applychain functionality from Flux.jl. The general direction is likely to interface with Accessors.jl. Some minor tests were added, but we will want to cover many more of the layer types as we expand the feature.
PR Checklist