Gradients for cumsum and cumprod #282

Closed
MikeInnes opened this issue Jul 30, 2019 · 12 comments

MikeInnes (Member) commented Jul 30, 2019

Including support for the dims keyword. Currently these break. We might have some gradients in Tracker already.
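
For reference, a minimal reproduction (hypothetical, but this is the kind of call that errored at the time):

using Zygote

# Both of these failed before adjoints existed for cumsum/cumprod:
gradient(x -> sum(cumsum(x, dims=1)), rand(3, 4))
gradient(x -> sum(cumprod(x)), rand(5))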

mcabbott (Member) commented:

I had a go in FluxML/Flux.jl#524, which is a bit messy, but the summary is:

  • cumsum is easy: the pullback is just Δ -> (reverse(cumsum(reverse(Δ, dims=d), dims=d), dims=d),). However, reverse(x; dims) doesn't work on CuArrays right now.

  • For cumprod I wrote a function which is careful about zero entries (a simplified sketch follows at the end of this comment). For the dims case you can apply it like mapslices(∇cumprod, x, p, Δ; dims=d), except you need to write that out by hand (perhaps with eachslice now). I don't know a neat way to do this more generically.

I'd be happy to tidy up these CPU versions, if that sounds good? On the GPU, cumsum(::CuVector) can easily work, but I don't promise to do the rest.
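
A minimal sketch of both pullbacks (simplified: the ∇cumprod below assumes x has no zeros, and the @adjoint form is just one hypothetical way to register the rule with Zygote):

using Zygote: @adjoint

# cumsum pullback via the reverse-cumsum-reverse trick:
# (dL/dx)[i] accumulates Δ[j] over j ≥ i along the chosen dimension.
@adjoint cumsum(x::AbstractArray; dims::Integer) =
    cumsum(x; dims=dims),
    Δ -> (reverse(cumsum(reverse(Δ, dims=dims), dims=dims), dims=dims),)

# Naive O(n²) cumprod pullback for vectors, assuming no zeros in x:
# with y = cumprod(x), ∂y[j]/∂x[i] = y[j]/x[i] for j ≥ i.
function ∇cumprod(x::AbstractVector, Δ::AbstractVector, y::AbstractVector)
    dx = zeros(float(eltype(x)), length(x))
    for i in eachindex(x), j in i:lastindex(x)
        dx[i] += Δ[j] * y[j] / x[i]
    end
    return dx
end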

kraftpunk97-zz commented:

@mcabbott I would like to take a shot at this. Why doesn't reverse(x; dims) work with CuArrays?

mcabbott (Member) commented Aug 6, 2019

Someone just has to write it, see JuliaGPU/CuArrays.jl#299. Now that I check, I see that cumsum(cu(rand(2,3)), dims=1) is also an error, so someone would have to write that too. Looks like the CuVector case is from JuliaGPU/CuArrays.jl#285.

kraftpunk97-zz commented Aug 6, 2019

cumsum seems fine to me...

julia> using CuArrays

julia> cumsum(cu(rand(2, 3)), dims=1)
2×3 CuArray{Float32,2}:
 0.172419  0.3915    0.629144
 0.931309  0.598612  1.28012 

reverse is a problem though. I'm working on it.

EDIT: Okay, I did some digging, and here's what I found...

julia> A = rand(2^10, 2^10);

julia> A_ = cu(A);

julia> @btime cumsum($A, dims=1)
  1.991 ms (2 allocations: 8.00 MiB)
1024×1024 Array{Float64,2}:
   0.261473    0.0214155    0.8109   …    0.313978    0.476666    0.341399
   0.968331    0.136674     1.10433       0.605896    0.549707    0.871519
   1.21307     0.339572     1.50853       0.834931    1.04689     1.13798 
   1.5554      0.882411     2.28349       1.47903     1.11472     1.26001 
   2.51751     1.01438      2.3474        2.44018     1.51983     1.46166 
   3.0148      1.94333      2.41744  …    3.3687      1.7526      1.59128 
   3.98639     2.02868      3.25437       4.25185     1.75492     1.74776 
   4.19211     2.58001      3.48065       5.18552     2.1872      2.65383 
   4.4275      3.4754       3.89582       5.43368     2.956       3.41412 
   4.47599     3.93059      4.2206        5.80521     3.44861     3.76436 
   ⋮                                 ⋱                                    
 506.243     504.758      495.649    …  509.346     517.034     498.596   
 507.106     505.361      496.597       510.228     517.32      499.541   
 507.394     506.133      497.406       510.956     517.808     499.77    
 507.762     506.796      497.808       511.655     518.407     499.938   
 507.882     507.7        497.953       511.88      518.932     500.92    
 508.478     508.363      498.283    …  512.814     519.305     501.593   
 508.824     509.052      499.033       512.825     520.24      501.641   
 509.682     509.743      499.191       513.801     520.37      502.183   
 510.629     510.513      499.332       513.926     520.684     503.14    

julia> @btime cumsum($A_, dims=1)
  50.284 s (8388364 allocations: 320.00 MiB)
1024×1024 CuArray{Float32,2}:
   0.261473    0.0214155    0.8109   …    0.313978    0.476666    0.341399
   0.968331    0.136674     1.10433       0.605896    0.549707    0.871519
   1.21307     0.339572     1.50853       0.834931    1.04689     1.13798 
   1.5554      0.882411     2.28349       1.47903     1.11472     1.26001 
   2.51751     1.01438      2.3474        2.44018     1.51983     1.46166 
   3.0148      1.94333      2.41744  …    3.3687      1.75259     1.59128 
   3.98639     2.02868      3.25436       4.25185     1.75492     1.74776 
   4.19211     2.58001      3.48065       5.18552     2.1872      2.65383 
   4.4275      3.4754       3.89582       5.43368     2.956       3.41412 
   4.47599     3.93059      4.2206        5.80521     3.44861     3.76436 
   ⋮                                 ⋱                                    
 506.243     504.758      495.649    …  509.346     517.034     498.595   
 507.106     505.361      496.597       510.227     517.32      499.541   
 507.394     506.133      497.406       510.955     517.808     499.77    
 507.762     506.796      497.807       511.654     518.407     499.938   
 507.882     507.7        497.953       511.88      518.932     500.92    
 508.479     508.363      498.283    …  512.814     519.305     501.593   
 508.824     509.052      499.033       512.825     520.241     501.641   
 509.682     509.743      499.191       513.801     520.37      502.183   
 510.629     510.513      499.332       513.926     520.685     503.139

@mcabbott cumsum is very slow with CuArray. Is this what you're talking about?

mcabbott (Member) commented Aug 6, 2019

Yes. I should have said: it's an error after CuArrays.allowscalar(false). Without that, I think you are hitting a fallback method which goes via the CPU, which is why it's so slow.
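
That is, something like this (CuArrays-era API):

using CuArrays
CuArrays.allowscalar(false)

cumsum(cu(rand(2, 3)), dims=1)   # errors once the scalar CPU fallback is disallowed
reverse(cu(rand(2, 3)), dims=1)  # likewise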

kraftpunk97-zz commented Aug 6, 2019

So writing customised kernel functions for cumsum and reverse should solve this problem, yes?
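
For illustration, a hypothetical per-element kernel for reverse along dims=1 of a matrix might look like this (a sketch against the CUDAnative-era API, not the code that was eventually merged):

using CuArrays, CUDAnative

# One thread per element; each writes its entry to the mirrored row.
function reverse_dim1_kernel!(out, src, m, n)
    idx = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    if idx <= m * n
        i = (idx - 1) % m + 1        # row
        j = (idx - 1) ÷ m + 1        # column
        @inbounds out[m - i + 1, j] = src[i, j]
    end
    return nothing
end

function reverse_dim1(src::CuMatrix)
    out = similar(src)
    m, n = size(src)
    @cuda threads=256 blocks=cld(m * n, 256) reverse_dim1_kernel!(out, src, m, n)
    return out
end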

kraftpunk97-zz commented Aug 10, 2019

@mcabbott Over the last few days I tried to implement a version of reverse that works with CuArrays of arbitrary dimensionality, and I think what I have is efficient enough for our purposes. I'm posting my results below. Are there any other requirements for cumsum? My concern is that since cu converts Int and other element types to Float32, gradient information could be dropped.

julia> shape_ = [2^10, 2^10];

julia> a = reshape(Vector(1:prod(shape_)), shape_...);

julia> a_ = cu(a);

julia> b = similar(a);

julia> b_ = similar(a_);

julia> @btime b = reverse(a, dims=2)
  2.099 ms (6 allocations: 8.00 MiB)
1024×1024 Array{Int64,2}:
 1047553  1046529  1045505  1044481  1043457  1042433  …   9217  8193  7169  6145  5121  4097  3073  2049  1025     1
 1047554  1046530  1045506  1044482  1043458  1042434      9218  8194  7170  6146  5122  4098  3074  2050  1026     2
 1047555  1046531  1045507  1044483  1043459  1042435      9219  8195  7171  6147  5123  4099  3075  2051  1027     3
 1047556  1046532  1045508  1044484  1043460  1042436      9220  8196  7172  6148  5124  4100  3076  2052  1028     4
 1047557  1046533  1045509  1044485  1043461  1042437      9221  8197  7173  6149  5125  4101  3077  2053  1029     5
 1047558  1046534  1045510  1044486  1043462  1042438  …   9222  8198  7174  6150  5126  4102  3078  2054  1030     6
       ⋮                                            ⋮  ⋱            ⋮                             ⋮                  
 1048572  1047548  1046524  1045500  1044476  1043452     10236  9212  8188  7164  6140  5116  4092  3068  2044  1020
 1048573  1047549  1046525  1045501  1044477  1043453  …  10237  9213  8189  7165  6141  5117  4093  3069  2045  1021
 1048574  1047550  1046526  1045502  1044478  1043454     10238  9214  8190  7166  6142  5118  4094  3070  2046  1022
 1048575  1047551  1046527  1045503  1044479  1043455     10239  9215  8191  7167  6143  5119  4095  3071  2047  1023
 1048576  1047552  1046528  1045504  1044480  1043456     10240  9216  8192  7168  6144  5120  4096  3072  2048  1024

julia> @btime my_reverse(a_, b_, 2)
  93.433 μs (33 allocations: 1.13 KiB)

julia> b_
1024×1024 CuArray{Float32,2}:
 1.04755e6  1.04653e6  1.04551e6  1.04448e6  1.04346e6  1.04243e6  …  6145.0  5121.0  4097.0  3073.0  2049.0  1025.0     1.0
 1.04755e6  1.04653e6  1.04551e6  1.04448e6  1.04346e6  1.04243e6     6146.0  5122.0  4098.0  3074.0  2050.0  1026.0     2.0
 1.04756e6  1.04653e6  1.04551e6  1.04448e6  1.04346e6  1.04244e6     6147.0  5123.0  4099.0  3075.0  2051.0  1027.0     3.0
 1.04756e6  1.04653e6  1.04551e6  1.04448e6  1.04346e6  1.04244e6     6148.0  5124.0  4100.0  3076.0  2052.0  1028.0     4.0
 1.04756e6  1.04653e6  1.04551e6  1.04449e6  1.04346e6  1.04244e6     6149.0  5125.0  4101.0  3077.0  2053.0  1029.0     5.0
 1.04756e6  1.04653e6  1.04551e6  1.04449e6  1.04346e6  1.04244e6  …  6150.0  5126.0  4102.0  3078.0  2054.0  1030.0     6.0
 ⋮                                                      ⋮          ⋱                             ⋮                          
 1.04857e6  1.04755e6  1.04652e6  1.0455e6   1.04448e6  1.04345e6     7164.0  6140.0  5116.0  4092.0  3068.0  2044.0  1020.0
 1.04857e6  1.04755e6  1.04653e6  1.0455e6   1.04448e6  1.04345e6  …  7165.0  6141.0  5117.0  4093.0  3069.0  2045.0  1021.0
 1.04857e6  1.04755e6  1.04653e6  1.0455e6   1.04448e6  1.04345e6     7166.0  6142.0  5118.0  4094.0  3070.0  2046.0  1022.0
 1.04858e6  1.04755e6  1.04653e6  1.0455e6   1.04448e6  1.04346e6     7167.0  6143.0  5119.0  4095.0  3071.0  2047.0  1023.0
 1.04858e6  1.04755e6  1.04653e6  1.0455e6   1.04448e6  1.04346e6     7168.0  6144.0  5120.0  4096.0  3072.0  2048.0  1024.0

mcabbott (Member) commented:

That sounds promising. Ideally this and a similar cumsum would become part of CuArrays. Be warned that you may need to insert CuArrays.@sync to get meaningful timings; I'm not sure.
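
For example (hypothetical, with the macro of the time):

using CuArrays, BenchmarkTools

# Without the sync, @btime may only measure the kernel launch, not its completion.
@btime CuArrays.@sync my_reverse($a_, $b_, 2)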

Re types: the function cu is opinionated, but e.g. CuArray(rand(5)) contains Float64, and integers should work too. Zygote does not mind if the reverse pass produces different types where needed, but that shouldn't be the case here.

kraftpunk97-zz commented:

Right, my bad. After making a few changes to my code and further levelling the playing field, @btime b = reverse(a, dims=2) on the CPU takes about 740 microseconds, and my implementation on the GPU takes about 130 microseconds.

AlexLewandowski commented:

Is there any update on this? Currently I get ERROR: Mutating arrays is not supported.

bors bot added a commit that referenced this issue Feb 27, 2020
284: adjoint for cumsum r=CarloLucibello a=mcabbott

The easy half of #282

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Co-authored-by: Michael Abbott <me@pseudomac>

Red-Portal commented Aug 15, 2023

Seems like the rule is not properly kicking in?

using Zygote, CUDA
gradient(x -> x |> cumprod |> sum, CUDA.randn(3))
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:103
  [3] getindex
    @ ~/.julia/packages/GPUArrays/5XhED/src/host/indexing.jl:9 [inlined]
  [4] ∇cumprod!
    @ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Base/mapreduce.jl:405 [inlined]
  [5] ∇cumprod(x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, dy::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, y::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ ChainRules ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Base/mapreduce.jl:395
  [6] #1693
    @ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Base/mapreduce.jl:341 [inlined]
  [7] unthunk
    @ ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:204 [inlined]
  [8] unthunk
    @ ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:237 [inlined]
  [9] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/chainrules.jl:110 [inlined]
 [10] map
    @ ./tuple.jl:274 [inlined]
 [11] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/chainrules.jl:111 [inlined]
 [12] ZBack
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/chainrules.jl:211 [inlined]
 [13] Pullback
    @ ./operators.jl:907 [inlined]
 [14] Pullback
    @ ./REPL[7]:1 [inlined]
 [15] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#31#32", CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(|>), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, typeof(cumprod)}, Tuple{Zygote.ZBack{ChainRules.var"#cumprod_pullback_1#1694"{Int64, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, Zygote.Pullback{Tuple{typeof(|>), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, typeof(sum)}, Tuple{Zygote.var"#4229#back#1457"{Zygote.var"#1453#1456"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:45
 [16] gradient(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:97
 [17] top-level scope
    @ REPL[7]:1
 [18] top-level scope
    @ ~/.julia/packages/CUDA/tVtYo/src/initialization.jl:185

ToucheSir (Member) commented Aug 15, 2023

The rule is being hit correctly per the stacktrace, but I don't think a GPU-compatible pullback was ever implemented on the ChainRules side. You'd want to open an issue with them.
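
For what it's worth, when x has no zeros the cumprod pullback reduces to vectorized operations that should be CuArray-compatible (a hedged sketch; zero entries would still need the careful scalar handling done in ChainRules):

# dx[i] = (Σ_{j≥i} Δ[j] * y[j]) / x[i], written with reverse/cumsum only:
∇cumprod_nozeros(x, Δ, y) = reverse(cumsum(reverse(Δ .* y))) ./ x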
