GPU kernels for optimizers #178
Comments
The kernel in NerfUtils.jl fuses several operations into a single kernel, while Optimisers.jl splits the update into 4 kernels (if counting the actual parameter update). For smaller arrays the benefit is negligible, but for something like the 100M-element array in the MWE below the fused version is roughly 2× faster:

```julia
using AMDGPU
using BenchmarkTools
using KernelAbstractions
using Flux
using NerfUtils

function main()
    x = AMDGPU.rand(Float32, 100 * 1024^2)
    dx = AMDGPU.rand(Float32, 100 * 1024^2)
    kab = get_backend(x)

    opt_1 = NerfUtils.Adam(kab, x)
    opt_2 = Flux.Optimisers.Adam()
    state = Flux.Optimisers.init(opt_2, x)

    @btime AMDGPU.@sync NerfUtils.step!($opt_1, $x, $dx; dispose=false)
    @btime AMDGPU.@sync begin
        ns, nx = Flux.Optimisers.apply!($opt_2, $state, $x, $dx)
        $x .-= nx
    end
    return
end
```

Timings:

```julia
julia> main()
  6.168 ms (395 allocations: 10.13 KiB)
  13.161 ms (339 allocations: 9.09 KiB)
```
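For reference, here is a minimal sketch of what such a fused update can look like with KernelAbstractions.jl. This is an illustration only, not the NerfUtils.jl kernel (see the link below for the real one); the kernel name `adam_step!` and the argument layout are made up for the example, and the hyperparameters are the standard Adam ones:

```julia
using KernelAbstractions

# Sketch of a fused Adam step: moment updates, bias correction and the
# parameter update all happen in one kernel, so x, dx, m and v are each
# read/written exactly once per step instead of once per broadcast.
@kernel function adam_step!(x, @Const(dx), m, v, β1, β2, ϵ, lr, t)
    i = @index(Global)
    @inbounds begin
        m[i] = β1 * m[i] + (1f0 - β1) * dx[i]
        v[i] = β2 * v[i] + (1f0 - β2) * dx[i]^2
        mhat = m[i] / (1f0 - β1^t)           # bias-corrected first moment
        vhat = v[i] / (1f0 - β2^t)           # bias-corrected second moment
        x[i] -= lr * mhat / (sqrt(vhat) + ϵ)
    end
end

# Launch on whatever backend `x` lives on, e.g.:
# adam_step!(get_backend(x))(x, dx, m, v, 0.9f0, 0.999f0, 1f-8, 1f-3, t;
#     ndrange = length(x))
```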
The reason Optimisers.jl rules are written the way they are is that we have to balance a few things. To demonstrate them, let's look at the implementation (Lines 219 to 221 in c2ae321).
Those are the constraints. My thoughts on where to proceed: we need a design which addresses most of them. Priority-wise, my ranking would be:

1. maintaining laziness when rules are composed (see the sketch after this list),
2. maximizing code reuse with non-GPU arrays, and
3. lowering the barrier of entry, so people don't have to understand all of KernelAbstractions to get started.

This all seems doable, but would require a champion to flesh out a design and push it.
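To make (1) concrete, here is a hypothetical sketch using plain Base broadcasting (not the actual Optimisers.jl internals; the rule names `lazy_clip` and `lazy_scale` are invented for the example) of how a rule can stay lazy so that composed rules still fuse into a single loop, or a single kernel for device arrays, at the final parameter update:

```julia
import Base.Broadcast: broadcasted

# Hypothetical rules that return *lazy* broadcasts instead of arrays.
# Nothing is allocated or executed until the result is materialized.
lazy_clip(dx, c)  = broadcasted(x -> clamp(x, -c, c), dx)
lazy_scale(dx, η) = broadcasted(*, η, dx)

x  = rand(Float32, 1024)
dx = rand(Float32, 1024)

# Compose two rules; dx′ is still an unevaluated broadcast tree.
dx′ = lazy_scale(lazy_clip(dx, 1f0), 1f-2)

# The whole chain (clip -> scale -> subtract) fuses into one loop here:
x .-= dx′
```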
Motivation and description
Wondering what kind of speedup can be achieved by writing GPU kernels for optimizers.
Take a look at @pxl-th's implementation of Adam below
https://github.com/JuliaNeuralGraphics/NerfUtils.jl/blob/main/src/nn/adam.jl#L100-L117
Possible Implementation
No response