Gradients for cumsum and cumprod #282
I had a go in FluxML/Flux.jl#524, which is a bit messy, but the summary is:
I'd be happy to tidy up these CPU versions, if that sounds good? On the GPU …
@mcabbott I would like to take a shot at this. Why doesn't …
Someone just has to write it, see JuliaGPU/CuArrays.jl#299. Now that I check, I see that …
EDIT: Okay, I did some digging, and here's what I found…
@mcabbott …
Yes. I should have said, it's an error after …
So writing customised kernel functions for …
@mcabbott Over the last few days, I tried to implement a version of …
That sounds promising. Ideally this and a similar … Re types, the function …
Right. My bad. After making a few changes to my code, and further levelling the playing field, …
Is there any update on this? Currently I get …
Seems like the rule is not properly kicking in?

```julia
using Zygote, CUDA

gradient(x -> x |> cumprod |> sum, CUDA.randn(3))
```

```
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] assertscalar(op::String)
   @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:103
 [3] getindex
   @ ~/.julia/packages/GPUArrays/5XhED/src/host/indexing.jl:9 [inlined]
 [4] ∇cumprod!
   @ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Base/mapreduce.jl:405 [inlined]
 [5] ∇cumprod(x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, dy::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, y::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
   @ ChainRules ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Base/mapreduce.jl:395
 [6] #1693
   @ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Base/mapreduce.jl:341 [inlined]
 [7] unthunk
   @ ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:204 [inlined]
 [8] unthunk
   @ ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:237 [inlined]
 [9] wrap_chainrules_output
   @ ~/.julia/packages/Zygote/4rucm/src/compiler/chainrules.jl:110 [inlined]
 [10] map
   @ ./tuple.jl:274 [inlined]
 [11] wrap_chainrules_output
   @ ~/.julia/packages/Zygote/4rucm/src/compiler/chainrules.jl:111 [inlined]
 [12] ZBack
   @ ~/.julia/packages/Zygote/4rucm/src/compiler/chainrules.jl:211 [inlined]
 [13] Pullback
   @ ./operators.jl:907 [inlined]
 [14] Pullback
   @ ./REPL[7]:1 [inlined]
 [15] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#31#32", CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(|>), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, typeof(cumprod)}, Tuple{Zygote.ZBack{ChainRules.var"#cumprod_pullback_1#1694"{Int64, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, Zygote.Pullback{Tuple{typeof(|>), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, typeof(sum)}, Tuple{Zygote.var"#4229#back#1457"{Zygote.var"#1453#1456"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}}}})(Δ::Float32)
   @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:45
 [16] gradient(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
   @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:97
 [17] top-level scope
   @ REPL[7]:1
 [18] top-level scope
   @ ~/.julia/packages/CUDA/tVtYo/src/initialization.jl:185
```
The rule is being hit correctly per the stacktrace, but I don't think a GPU-compatible pullback was ever implemented on the ChainRules side. You'd want to open an issue with them.
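For illustration, here is a minimal sketch of a pullback that avoids scalar indexing by using only broadcasting, `cumsum` and `reverse`. It is not ChainRules' `∇cumprod`; `cumsum_pullback` and `cumprod_pullback` are hypothetical names. The `cumprod` version divides by `x`, so it assumes no entry of `x` is zero, and whether these calls stay free of scalar indexing on a given GPU array type (e.g. `CuArray`) is an assumption to check against your package versions.

```julia
# Hypothetical helpers, not part of ChainRules or Zygote.

# y = cumsum(x) gives y[j] = x[1] + ... + x[j], so ∂L/∂x[i] = dy[i] + ... + dy[end]:
# a cumulative sum of dy taken from the end backwards.
cumsum_pullback(dy::AbstractVector) = reverse(cumsum(reverse(dy)))

# y = cumprod(x) gives y[j] = x[1] * ... * x[j], so ∂y[j]/∂x[i] = y[j] / x[i] for i <= j,
# hence ∂L/∂x[i] = (dy[i]*y[i] + ... + dy[end]*y[end]) / x[i].
# ASSUMPTION: no entry of x is zero; a zero would need a separate code path.
cumprod_pullback(x::AbstractVector, y::AbstractVector, dy::AbstractVector) =
    cumsum_pullback(dy .* y) ./ x

# Usage: for L(x) = sum(cumprod(x)), the incoming gradient dy is a vector of ones.
x = [1.0, 2.0, 3.0]
y = cumprod(x)                                # [1.0, 2.0, 6.0]
dx = cumprod_pullback(x, y, ones(length(x)))  # [9.0, 4.0, 2.0]
```

The division trick keeps everything as whole-array operations, which is what makes it a candidate for GPU arrays; the price is the zero-entry restriction.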
Including support for `dims`. Currently these break. We might have some gradients in Tracker already.
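As a sketch of what `dims` support could look like, the pullback of `cumsum` along dimension `d` is a cumulative sum of `dy` taken from the end backwards along `d`, and Base's `reverse` and `cumsum` both accept a `dims` keyword. The function names below are hypothetical, and the `cumprod` variant again assumes `x` has no zero entries.

```julia
# Hypothetical helpers, not part of ChainRules, Zygote or Tracker.

# Pullback of y = cumsum(x; dims=d): reverse, accumulate and reverse back along d.
cumsum_pullback_dims(dy::AbstractArray; dims::Integer) =
    reverse(cumsum(reverse(dy; dims=dims); dims=dims); dims=dims)

# Pullback of y = cumprod(x; dims=d).
# ASSUMPTION: no entry of x is zero, because the result is divided by x.
cumprod_pullback_dims(x::AbstractArray, y::AbstractArray, dy::AbstractArray; dims::Integer) =
    cumsum_pullback_dims(dy .* y; dims=dims) ./ x

# Usage: L(x) = sum(cumprod(x; dims=1)), so dy is all ones.
x = [1.0 2.0; 3.0 4.0]
y = cumprod(x; dims=1)                                   # [1.0 2.0; 3.0 8.0]
dx = cumprod_pullback_dims(x, y, ones(size(x)); dims=1)  # [4.0 5.0; 1.0 2.0]
```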