Closed
Description
The adjoint is not implemented for pooling functions that take a custom `c` as input. In this case, Zygote produces `nothing` for the gradient.
Usually we don't specify `c`, but this becomes an issue if the last node in the array is disconnected: specifying `c` makes sure the dimensions are consistent.
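To make the dimension point concrete, here is a minimal sketch in plain Julia. `sumpool_sketch` is a hypothetical stand-in written for this illustration, not the library's implementation. It assumes that pooling adds column `i` of `X` to column `cluster[i]` of the output, and that an omitted `c` defaults to the number of distinct cluster ids.

```julia
# Hypothetical stand-in for a sum-pooling function (not the library's code).
# Column i of X is added to column cluster[i] of the output, which has c columns.
# Assumption: when c is omitted, it defaults to the number of distinct cluster ids.
function sumpool_sketch(cluster::AbstractVector{Int}, X::AbstractMatrix{T},
                        c::Integer=length(Set(cluster))) where {T<:Real}
    Y = zeros(T, size(X, 1), c)
    for (i, k) in enumerate(cluster)
        Y[:, k] .+= X[:, i]
    end
    return Y
end

cluster = [1, 2, 2, 1]   # no column is assigned to a third cluster
X = ones(4, 4)

Y_default = sumpool_sketch(cluster, X)     # c inferred as 2
Y_fixed   = sumpool_sketch(cluster, X, 3)  # c = 3 keeps an all-zero third column

println(size(Y_default))  # (4, 2)
println(size(Y_fixed))    # (4, 3)
```

Under these assumptions, passing `c` explicitly is what keeps the output width fixed when the last cluster receives no columns.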
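Example:
using Zygote  # provides the f'(x) gradient syntax; sumpool is assumed to be in scope already

# Without c: the gradient is returned.
cluster = [1; 2; 2; 1]
z = ones(4, 4)
test(z) = sum(sumpool(cluster, z))
test'(z)  # returns the gradient

# With c = 2: the gradient is missing.
cluster = [1; 2; 2; 1]
z = ones(4, 4)
test(z) = sum(sumpool(cluster, z, 2))
test'(z)  # returns nothing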
Proposed fix:
I wonder if adding the following would fix it?
@adjoint sumpool(cluster::AbstractArray{Int}, X::AbstractArray{T}, c::Int) where {T<:Real} =
    sumpool(cluster, X, c), Δ -> (nothing, gather(zero(Δ)+Δ, cluster), nothing)
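One way to sanity-check the shape of this pullback without any packages is to compare it against finite differences. The sketch below is only an illustration: `sumpool_sketch`, `gather_sketch`, `pullback_sketch`, and `fd_gradient` are hypothetical helpers, and it assumes that `gather(Δ, cluster)` selects column `cluster[i]` of `Δ` for each input column `i`.

```julia
# Hypothetical stand-ins (not the library's code), written only for this check.
function sumpool_sketch(cluster, X, c)
    Y = zeros(eltype(X), size(X, 1), c)
    for (i, k) in enumerate(cluster)
        Y[:, k] .+= X[:, i]
    end
    return Y
end

# Assumed gather semantics: column i of the result is column cluster[i] of Δ.
gather_sketch(Δ, cluster) = Δ[:, cluster]

# Pullback in the shape of the proposed adjoint: no gradient for cluster or c.
pullback_sketch(Δ, cluster) = (nothing, gather_sketch(Δ, cluster), nothing)

cluster = [1, 2, 2, 1]
X = rand(4, 4)
c = 3
W = rand(4, c)   # arbitrary weights so the upstream gradient is not all ones

loss(X) = sum(W .* sumpool_sketch(cluster, X, c))

# For this loss, the upstream gradient with respect to the pooled output is W.
_, dX, _ = pullback_sketch(W, cluster)

# Central finite differences as an independent reference.
function fd_gradient(f, X; h=1e-6)
    G = similar(X)
    for idx in eachindex(X)
        Xp = copy(X); Xp[idx] += h
        Xm = copy(X); Xm[idx] -= h
        G[idx] = (f(Xp) - f(Xm)) / (2h)
    end
    return G
end

println(isapprox(dX, fd_gradient(loss, X); atol=1e-6))
```

If the two gradients agree, that supports returning `nothing` for both `cluster` and `c` and the gathered `Δ` for `X`. It does not confirm how Zygote dispatches on the three-argument method.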