
GPU kernels for optimizers #178

Open
vpuri3 opened this issue Jul 8, 2024 · 2 comments
Labels
enhancement New feature or request

Comments

@vpuri3

vpuri3 commented Jul 8, 2024

Motivation and description

I'm wondering what kind of speedup could be achieved by writing fused GPU kernels for optimizers.

Take a look at @pxl-th's implementation of Adam below:

https://github.com/JuliaNeuralGraphics/NerfUtils.jl/blob/main/src/nn/adam.jl#L100-L117
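
For reference, a fused update in that style might look roughly like the following. This is only a minimal KernelAbstractions sketch of a single Adam step, not the NerfUtils.jl code; the kernel name, launcher, and hyperparameter handling are illustrative assumptions.

using KernelAbstractions

# Hypothetical fused Adam step: moment updates, bias correction and the
# parameter update all happen inside one kernel launch.
@kernel function adam_step_kernel!(x, m, v, @Const(dx), β1, β2, β1t, β2t, η, ϵ)
    i = @index(Global)
    @inbounds begin
        m[i] = β1 * m[i] + (1f0 - β1) * dx[i]
        v[i] = β2 * v[i] + (1f0 - β2) * dx[i]^2
        m̂ = m[i] / (1f0 - β1t)   # bias-corrected first moment
        v̂ = v[i] / (1f0 - β2t)   # bias-corrected second moment
        x[i] -= η * m̂ / (sqrt(v̂) + ϵ)
    end
end

# Illustrative launcher: m and v are preallocated moment buffers on the same device as x,
# t is the current step count.
function fused_adam_step!(x, m, v, dx, t; η=1f-3, β1=0.9f0, β2=0.999f0, ϵ=1f-8)
    backend = get_backend(x)
    adam_step_kernel!(backend)(x, m, v, dx, β1, β2, β1^t, β2^t, η, ϵ; ndrange=length(x))
    return x
end

A real implementation would keep the moments and step counter inside its optimizer state; the sketch passes them explicitly to stay self-contained.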

Possible Implementation

No response

@pxl-th
Member

pxl-th commented Jul 9, 2024

The kernel in NerfUtils.jl fuses several operations into a single kernel, while Optimisers.jl splits it up into 4 (if counting the actual parameter update).

For smaller arrays the benefit is negligible, but for arrays on the order of 400 MB or more the fused kernel is roughly 2x faster.

MWE:

using AMDGPU
using BenchmarkTools
using KernelAbstractions
using Flux
using NerfUtils

function main()
    # ~400 MB each of parameters and gradients.
    x = AMDGPU.rand(Float32, 100 * 1024^2)
    dx = AMDGPU.rand(Float32, 100 * 1024^2)

    kab = get_backend(x)

    opt_1 = NerfUtils.Adam(kab, x)
    opt_2 = Flux.Optimisers.Adam()
    state = Flux.Optimisers.init(opt_2, x)

    # Fused, single-kernel Adam step.
    @btime AMDGPU.@sync NerfUtils.step!($opt_1, $x, $dx; dispose=false)

    # Optimisers.jl path: apply! followed by the parameter update.
    @btime AMDGPU.@sync begin
        ns, nx = Flux.Optimisers.apply!($opt_2, $state, $x, $dx)
        $x .-= nx
    end
    return
end

Timings:

julia> main()
  6.168 ms (395 allocations: 10.13 KiB)
  13.161 ms (339 allocations: 9.09 KiB)

@ToucheSir added the enhancement label Jul 9, 2024
@ToucheSir
Member

Optimisers.jl rules are written the way they are because we have to balance a few things. To demonstrate them, let's look at the implementation of Adam:

Optimisers.jl/src/rules.jl

Lines 219 to 221 in c2ae321

@.. mt = β[1] * mt + (1 - β[1]) * dx
@.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
dx′ = @lazy mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η

  1. Broad array type compatibility: that @.. macro is not a typo. It's a custom version of @. which writes in place where possible and returns a new array for immutable array types. Custom kernels only work with mutable array types, and a limited number of them at that.
  2. Deferring work where possible: "...while Optimisers.jl splits it up into 4 (if counting the actual parameter update)" is off by one, because the @lazy means dx′ is a Broadcasted instead of a materialized array. We do this to ensure better fusion with subsequent steps (think of how AdamW does Adam + WeightDecay), as well as fusion with the final parameter update step (see the sketch after this list). Writing a standalone GPU kernel for each AbstractRule would mean we lose out on this fusion.
  3. Legibility and ease of entry: most people who contribute rules to Optimisers.jl are not super familiar with writing GPU code. Our current system for writing rules seems to be pretty accessible, since most of the work is translating statements of math into statements of array-level Julia code. Unless we want to make follow-up PRs for GPU kernels every time a new rule is added, we'll want a system which lowers the barrier to entry somehow.
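
To make point 2 concrete, here is a small illustration using plain Base broadcasting (not the actual Optimisers.jl internals; the variable names are made up) of what keeping dx′ lazy buys: the final update materializes the whole expression in one fused broadcast instead of allocating an intermediate array.

# Hypothetical illustration with plain Base broadcasting, not Optimisers.jl code.
mt, vt = rand(Float32, 10), rand(Float32, 10)
x = rand(Float32, 10)
η, ϵ, βt = 1f-3, 1f-8, (0.9f0, 0.999f0)

# Eager: the right-hand side allocates a temporary array for dx′ before the update.
dx′_eager = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η
x .-= dx′_eager

# Lazy (what @lazy amounts to, per the excerpt above): dx′ is a Broadcasted expression
# tree, so the update below fuses the whole computation into one loop/kernel with no
# temporary array.
dx′ = Base.Broadcast.broadcasted((m, v) -> m / (1 - βt[1]) / (sqrt(v / (1 - βt[2])) + ϵ) * η, mt, vt)
x .-= dx′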

Those are the constraints. My thoughts on where to proceed: we need a design which addresses most of them. Priority-wise, my ranking would be 1) maintaining laziness when rules are composed, 2) maximizing code reuse with non-GPU arrays, and 3) lowering the barrier to entry so people don't have to understand all of KernelAbstractions to get started. This all seems doable, but it would require a champion to flesh out the design and push it forward.
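
For the sake of discussion, one possible shape for such a design, purely as a sketch (the fused_update! hook below does not exist in Optimisers.jl): keep the current broadcast-based rules as the generic fallback and let backends or rule authors opt into a fused path via dispatch.

# Purely hypothetical sketch; fused_update! is not part of Optimisers.jl today.
using Optimisers
using GPUArraysCore: AbstractGPUArray

# Fallback: reuse the existing lazy, broadcast-based rule definitions for any array type.
function fused_update!(rule::Optimisers.AbstractRule, state, x::AbstractArray, dx)
    state, dx′ = Optimisers.apply!(rule, state, x, dx)
    x .-= dx′   # dx′ may still be lazy here, so the generic path keeps its fusion
    return state
end

# Rule authors (or GPU backend packages) comfortable with KernelAbstractions could then
# add opt-in specializations that launch a single fused kernel, along the lines of the
# adam_step_kernel! sketched earlier in this thread:
#
#     fused_update!(rule::Optimisers.Adam, state, x::AbstractGPUArray, dx::AbstractGPUArray) = ...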
