Inplace getindex rrule #240
For Zygote to benefit, it needs FluxML/Zygote.jl#603.
Ok, seems like I need to do a bunch more methods of
This broadly LGTM. It would be good to add some tests involving arrays with an element type that's not a number -- perhaps an array of arrays?
Good idea.
Something @willtebbutt pointed out to me the other day:
We can use ChainRules's ability to define an in-place addition operation (the InplaceableThunk) when defining the gradient of getindex.
This would be much more efficient than, e.g., returning a one-hot dense matrix to be summed, and marginally more efficient than returning a one-hot sparse matrix to be summed.
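To make the efficiency argument concrete, here is a minimal sketch in plain Python. It is not the ChainRules.jl API: the function names `getindex_pullback_dense` and `getindex_pullback_inplace` are hypothetical, the arrays are ordinary lists, and indices are 0-based rather than Julia's 1-based. It only illustrates the difference between allocating a one-hot gradient per use and adding into an existing accumulator.

```python
# Illustrative only: plain Python lists, not ChainRules.jl types.

def getindex_pullback_dense(n, i, dy):
    """Allocate a fresh one-hot gradient of length n for y = x[i]."""
    dx = [0.0] * n
    dx[i] = dy
    return dx

def getindex_pullback_inplace(i, dy):
    """Return a function that adds the gradient into an existing accumulator."""
    def add_to(acc):
        acc[i] += dy  # touches one element, allocates nothing
        return acc
    return add_to

n = 5
# (index, upstream gradient) for each use of x[i]; index 1 is used twice.
uses = [(1, 2.0), (3, 0.5), (1, 1.0)]

# Dense strategy: one length-n allocation and one length-n sum per use.
total_dense = [0.0] * n
for i, dy in uses:
    onehot = getindex_pullback_dense(n, i, dy)
    total_dense = [a + b for a, b in zip(total_dense, onehot)]

# In-place strategy: a single accumulator, O(1) work per use.
total_inplace = [0.0] * n
for i, dy in uses:
    getindex_pullback_inplace(i, dy)(total_inplace)
```

Both strategies produce the same accumulated gradient; the in-place one does constant work per use instead of work proportional to the array length.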
It might also mean we would not need to say that the differential for a primal type depends on both the primal type and the operation.
Though we still might want to.
Right now Zygote won't use ChainRules's in-place accumulation stuff, but idk how hard it would be to enable it.
I think it might be fine or it might (only) break nesting Zygote.
It would definitely require the deeper change over to ChainRules's types.
I am tempted to leave this WIP for a while and improve our abstractions, while trying to get Zygote to actually use this. (Though that is a bigger project, as it needs Zygote to use ChainRules's types.)