Chainrule for CUDA reduction #666

Open · renatobellotti opened this issue Aug 20, 2022 · 4 comments

@renatobellotti

Hi,

I'd like to suggest adding a rule for GPU reductions.

using CUDA, Zygote

function my_loss(v)
    # This works:
    # l = sum(v)
    # This does not work:
    l = reduce(+, v)
    return l
end

v = cu([1., 2.])
Zygote.gradient(my_loss, v)

See also: FluxML/Zygote.jl#730 (comment)

@mcabbott added the "enhancement" and "good first issue" labels on Aug 20, 2022
@mcabbott (Member)

rrule(reduce, +, x; kw...) can just call rrule(sum, x; kw...), right?
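
A minimal sketch of that delegation (not a committed implementation; it assumes ChainRules.jl's existing sum rule is loaded, and ignores reduce-only keywords such as init):

using ChainRules, ChainRulesCore

function ChainRulesCore.rrule(::typeof(reduce), ::typeof(+), x::AbstractArray; kw...)
    # forward: reuse the existing sum rule, so CuArrays keep their GPU sum
    y, sum_pullback = ChainRulesCore.rrule(sum, x; kw...)
    function reduce_pullback(ȳ)
        _, x̄ = sum_pullback(ȳ)
        # reduce has one extra argument (the operator +), so return one extra NoTangent
        return (NoTangent(), NoTangent(), x̄)
    end
    return y, reduce_pullback
end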

@mcabbott added the "missing rule" label and removed the "enhancement" label on Aug 20, 2022
@renatobellotti (Author) commented Aug 20, 2022

Isn't the reduction implemented on the GPU? I don't know the details, but reducing on the GPU and then copying the result is certainly more efficient than copying the entire vector and reducing on the CPU.

@mcabbott (Member)

Sure. For the forward pass, the rrule for sum just calls sum again on what it's given, and thus uses the same GPU code as without AD. (And the reverse pass is written using broadcasting, which also works on the GPU.)
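
For the curious, a rough sketch of that shape (simplified, not the actual ChainRules source):

using ChainRulesCore

function sum_rrule_sketch(x::AbstractArray)
    y = sum(x)  # forward pass: the same sum call, so a CuArray takes the GPU path
    # reverse pass: the gradient of sum is one for every element; building it
    # by broadcasting means it, too, dispatches to GPU kernels for CuArrays
    sum_pullback(ȳ) = (NoTangent(), ȳ .* one.(x))
    return y, sum_pullback
end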

@renatobellotti (Author)

Nice!
