The right way to implement rrule(broadcasted, f, args...)
#531
This seems to give sensible answers. But the tester might be confused if it compares to … Note that you will need something like Zygote's …
So I guess the optimal approach would be to avoid …

```julia
ys = f.(xs...)
...
dxs = derivatives_given_output.(ys, f, xs...)  # this already works for log()
```

I already have a very similar mechanism in Yota (example), so it would be pretty natural to extend it with … Regarding …
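The unfused pattern above can be sketched in plain Julia, without ChainRulesCore, to show how the pieces fit together. One broadcast computes the outputs, and the pullback broadcasts again to get per-element partials from the stored outputs. Everything here is an assumption for illustration: `partials_given_output` is a hypothetical stand-in for `derivatives_given_output` (the real one returns a tuple of tuples, one per output; this one returns a flat tuple and knows only `log`, `exp` and `*`), `reduce_to_shape` is a hypothetical helper that sums a gradient over broadcast dimensions, and at least one argument is assumed to be an array.

```julia
# Hypothetical stand-in for derivatives_given_output: partials of f with respect
# to each argument, computed from the output y and the inputs (flat tuple).
partials_given_output(y, ::typeof(log), x) = (inv(x),)
partials_given_output(y, ::typeof(exp), x) = (y,)        # exp'(x) == exp(x) == y
partials_given_output(y, ::typeof(*), a, b) = (b, a)

# Hypothetical helper: sum a gradient over the dimensions along which x was
# broadcast (size 1 in x, larger in dx), then restore x's shape.
function reduce_to_shape(dx::AbstractArray, x::AbstractArray)
    dims = Tuple(d for d in 1:ndims(dx) if size(x, d) == 1 && size(dx, d) != 1)
    summed = isempty(dims) ? dx : sum(dx; dims=dims)
    return reshape(summed, size(x))
end
reduce_to_shape(dx::AbstractArray, x::Number) = sum(dx)   # scalar argument

# Unfused rule: one broadcast for the outputs; the pullback broadcasts again
# to get per-element partials from the stored outputs.
function unfused_broadcast_rrule(f, xs...)
    ys = f.(xs...)
    function pullback(dys)
        ds = partials_given_output.(ys, f, xs...)   # array of tuples of partials
        return ntuple(length(xs)) do i
            reduce_to_shape(dys .* getindex.(ds, i), xs[i])
        end
    end
    return ys, pullback
end
```

Because `ys` is kept for the backward pass, a rule like the one for `exp` can reuse its output instead of recomputing it; the partials themselves are only materialised when the pullback runs.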
If the link works, this is my attempt at an unfused broadcasting rule. This uses … No idea how stable you should regard …

I think the question of how much better a fused broadcast can be boils down to whether it can be made fast, and whether it can save memory. Pushing dual numbers through (like Zygote does) amounts to quite a few copies, as does storing an array of closures. There might be ways to fuse some simple operations like …
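To make the memory point about dual numbers concrete, here is a minimal forward-mode sketch in plain Julia of pushing duals through a unary broadcast. The `Dual` type and `dual_broadcast_rrule` are hypothetical illustrations, not Zygote's code (Zygote's own path uses ForwardDiff's dual numbers). Even though the broadcast itself is fused, the forward pass materialises an array of duals and then splits it into two further arrays, one of which must be kept alive for the pullback.

```julia
# Minimal forward-mode dual number carrying one derivative component.
# Hypothetical type for illustration only.
struct Dual{T<:Real}
    val::T
    der::T
end

# Chain rule on duals for a few scalar functions.
Base.log(d::Dual) = Dual(log(d.val), d.der / d.val)
Base.exp(d::Dual) = (e = exp(d.val); Dual(e, e * d.der))
Base.sin(d::Dual) = Dual(sin(d.val), cos(d.val) * d.der)

# Broadcast rule for a unary f: one fused broadcast pushes duals through f,
# then the result is split into two further arrays.
function dual_broadcast_rrule(f, x::AbstractArray{<:Real})
    out = f.(Dual.(x, one.(x)))     # fused: allocates one array of Duals
    y = map(d -> d.val, out)        # copy: primal outputs
    dydx = map(d -> d.der, out)     # copy: derivatives, kept for the pullback
    pullback(dy) = (dy .* dydx,)
    return y, pullback
end
```

For an `f` of several array arguments, each dual would need one partial per argument, which multiplies the stored data accordingly.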
Not at all stable.
The link opens a new PR page and lets me pretend to be you 😄. But yes, I can see the code and all the different paths you've considered 👍. Given my simple use case of handling functions like …
I believe there's nothing more to do in this issue, so I'm closing it. Thanks for sharing all this great stuff with me!
If you open the PR, then you get to fix the bugs! One more question is whether rules for broadcasting should live here in ChainRules. If they are defined elsewhere through …
In some other discussion I proposed a generic implementation of `rrule` for `broadcasted`, a slightly modified version of which looks like this (using `rrule` instead of `rrule_via_ad` for simplicity here): …

Empirically, I can see that it works correctly at least in simple cases, e.g.: …
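The implementation itself is not preserved here, but the general pattern being discussed, calling a scalar `rrule` on every element and keeping the returned pullbacks, can be sketched in plain Julia. This is not the author's code: `scalar_rrule` is a hypothetical stand-in for `ChainRulesCore.rrule` that knows only `*` and `log`, and the sketch assumes every argument already has the output's shape, so no summing over broadcast dimensions is done.

```julia
# Hypothetical scalar rules in the (output, pullback) style of rrule.
# Each pullback returns one tangent per argument (none for f itself here).
scalar_rrule(::typeof(*), a, b) = (a * b, dy -> (dy * b, dy * a))
scalar_rrule(::typeof(log), x) = (log(x), dy -> (dy / x,))

# Generic broadcast rule built from per-element rules: the forward pass keeps
# an array of closures, one per element, alive until the backward pass.
function closure_broadcast_rrule(f, xs::AbstractArray...)
    results = scalar_rrule.(f, xs...)        # array of (y, pullback) tuples
    ys = first.(results)
    pullbacks = last.(results)               # array of closures
    function pullback(dys)
        per_elem = map((pb, dy) -> pb(dy), pullbacks, dys)  # array of tangent tuples
        # "Unzip": turn the array of tuples into one array per argument.
        return ntuple(i -> getindex.(per_elem, i), length(xs))
    end
    return ys, pullback
end
```

Each element keeps its own closure, and whatever it captured, until the pullback runs. A full `rrule(broadcasted, f, args...)` would also have to return tangents for `broadcasted` itself and for `f` ahead of the argument tangents, as ChainRules' `rrule` convention requires.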
But when I run `test_rrule(Broadcast.broadcasted, f, xs; check_inferred=false)`, I get a strange error: …

Note the two error messages: … and …
Can you see a mistake in this implementation, or is it just too complicated for `test_rrule()` to verify?
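`test_rrule` from ChainRulesTestUtils checks a rule's pullback against finite differences (it relies on FiniteDifferences.jl for this). When its output is hard to read, a stripped-down version of the same comparison can help separate "the rule is wrong" from "the tester is confused". The sketch below is plain Julia and checks only the array argument of a unary broadcast; `fd_vjp` and `check_pullback` are hypothetical helpers, and the step size `h` and tolerance `atol` are arbitrary choices, not what `test_rrule` uses.

```julia
# Central-difference vector-Jacobian product: dy' * J, where J = ∂vec(g(x))/∂vec(x).
function fd_vjp(g, x::Vector{Float64}, dy::AbstractArray; h=1e-6)
    grad = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        grad[i] = sum(dy .* (g(xp) .- g(xm))) / (2h)
    end
    return grad
end

# Compare a hand-written pullback for x -> f.(x) against finite differences.
# `rule(f, x)` is assumed to return (y, pullback) with pullback(dy) == (dx,).
function check_pullback(rule, f, x::Vector{Float64}; atol=1e-6)
    y, pb = rule(f, x)
    dy = randn(size(y)...)              # random cotangent, as a tester would pick
    (dx,) = pb(dy)
    dx_fd = fd_vjp(z -> f.(z), x, dy)
    return isapprox(dx, dx_fd; atol=atol)
end
```

For example, `check_pullback((f, x) -> (f.(x), dy -> (dy ./ x,)), log, [1.0, 2.0])` should return `true`, while a pullback that multiplies by `x` instead of dividing should return `false`.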