missing adjoints for pooling functions with specified graph size #147

Closed
wwang2 opened this issue Dec 28, 2020 · 5 comments

Comments

@wwang2

wwang2 commented Dec 28, 2020

No adjoint is implemented for the pooling functions when a custom c is passed as input. In that case, Zygote returns nothing for the gradient.

Usually we don't specify c, but that becomes a problem if the last node in the array is disconnected: specifying c makes sure the dimensions are consistent.

example:

cluster = [1; 2; 2; 1]
z = ones(4, 4)
test(z) = sum(sumpool(cluster, z))
test'(z) # returns the gradient

cluster = [1; 2; 2; 1]
z = ones(4, 4)
test(z) = sum(sumpool(cluster, z, 2))
test'(z) # returns nothing

proposed fix:

I wonder if adding the following would fix it?

@adjoint sumpool(cluster::AbstractArray{Int}, X::AbstractArray{T}, c::Int) where {T<:Real} =
    sumpool(cluster, X, c), Δ -> (nothing, gather(zero(Δ)+Δ, cluster), nothing)
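
For reference, here is a minimal, dependency-free sketch of what such an adjoint would compute. It assumes sumpool sums the columns of X that share a cluster label into c output columns. The names sumpool_ref and sumpool_ref_pullback are hypothetical stand-ins for illustration, not the library's implementation.

```julia
# Hypothetical reference versions for illustration only.

# Sum the columns of X that share a cluster label into c output columns.
function sumpool_ref(cluster::AbstractVector{<:Integer}, X::AbstractMatrix, c::Integer)
    Y = zeros(eltype(X), size(X, 1), c)
    for (i, k) in enumerate(cluster)
        @views Y[:, k] .+= X[:, i]
    end
    return Y
end

# Pullback with respect to X: column i of the gradient is Δ[:, cluster[i]],
# which is a gather of Δ along the cluster labels.
sumpool_ref_pullback(cluster, Δ::AbstractMatrix) = Δ[:, cluster]

cluster = [1, 2, 2, 1]
X = ones(4, 4)

# Request 3 clusters although only labels 1 and 2 occur, so the last
# output column receives no nodes.
Y = sumpool_ref(cluster, X, 3)          # size (4, 3); third column stays zero
Δ = ones(size(Y))                       # gradient of sum(Y) with respect to Y
dX = sumpool_ref_pullback(cluster, Δ)   # size (4, 4); every entry is 1.0
```

Under these assumptions the gradient with respect to X has the same shape as X regardless of c, which is the behaviour the proposed adjoint aims for.
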
@yuehhua
Member

yuehhua commented Jan 6, 2021

Thank you for pointing this out. The pooling function here is not sound.
It originally comes from the scatter! functions in yuehhua/ScatterNNlib.jl.
Meanwhile, I was asked to move these fundamental functions into NNlib.jl and CUDA.jl. You can check the status in yuehhua/ScatterNNlib.jl#32.
I am currently working on FluxML/NNlib.jl#255, and the pooling functions here will become scatter (the non-in-place version).
The gradients for these functions will be implemented as well, for Zygote.jl and ChainRules.jl.

@wwang2
Author

wwang2 commented Jan 8, 2021

Thanks. That sounds great.

I also want to comment on the performance of scatter operations, in case it is helpful. Since the adjoint of view is effectively sumpool, I checked how Zygote.jl implements it (the adjoint of view is implemented here). It turns out to be very efficient, and it does not seem to use SIMD.

computational example:
first, generate a random cluster and some arrays

cluster = rand(1:100,200);
h = rand(3, 100);
e = rand(3, 200)
@btime msg, pool = Zygote.pullback(view, h, :, cluster)
msg, pool = Zygote.pullback(view, h, :, cluster)
@btime pool(e)[1]
  1.808 μs (18 allocations: 736 bytes)
  1.867 μs (3 allocations: 2.55 KiB)

But if we use sumpool:

@btime sumpool($cluster, $e, 100)
51.841 μs (801 allocations: 21.25 KiB)

Does this mean that there is more room to improve efficiency?

@yuehhua
Member

yuehhua commented May 28, 2021

@wwang2 The current status is updated in yuehhua/ScatterNNlib.jl#32.
sumpool has been replaced by scatter(+, src, idx), where cluster is renamed to idx and X to src.
The new APIs are defined in NNlib.jl. In the future, GeometricFlux.jl will rely on NNlib.jl directly instead of ScatterNNlib.jl.
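
A short sketch of the renamed call, assuming an NNlib.jl release that provides scatter and gather and using the toy inputs from the original report. The variable names pooled and back are illustrative only.

```julia
using NNlib  # assumes an NNlib.jl version that includes scatter and gather

idx = [1, 2, 2, 1]   # formerly `cluster`
src = ones(4, 4)     # formerly `X`

# Sum the columns of src that share an index into one output column.
pooled = NNlib.scatter(+, src, idx)

# gather goes the other way: column i of the result is pooled[:, idx[i]].
back = NNlib.gather(pooled, idx)
```

With these inputs, pooled has one column per distinct index and back has the same shape as src.
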

@wwang2
Author

wwang2 commented May 29, 2021

@yuehhua awesome, thanks for the nice contribution. I will surely use it for my work.

@wwang2 wwang2 closed this as completed May 31, 2021
@yuehhua
Member

yuehhua commented Jul 2, 2021

@wwang2 All APIs for scatter and gather have been migrated to NNlib.jl and NNlibCUDA.jl. You may want to check them out.
