diff --git a/src/scatter.jl b/src/scatter.jl
index f637b6bed..2be655cf5 100644
--- a/src/scatter.jl
+++ b/src/scatter.jl
@@ -155,3 +155,127 @@ function scatter(op::typeof(mean),
     fill!(dst, Base.reduce_empty(+, FT))
     scatter!(op, dst, src, idx)
 end
+
+## Gradients
+
+# Symbolic name of each arithmetic scatter operator.
+opname(::typeof(+)) = :add
+opname(::typeof(-)) = :sub
+opname(::typeof(*)) = :mul
+opname(::typeof(/)) = :div
+
+# Cotangent of `dst` for `scatter!`: `Δ` is passed through unchanged for every `op`.
+# NOTE(review): for `*` and `/` this looks exact only when the scattered `src` entries
+# are all one; confirm against the definition of `scatter!`.
+∇scatter_dst!(op, Δ, dst, y) = Δ
+
+# TODO: enable once the `max`/`min` gradients pass the tests.
+# function ∇scatter_dst!(op::Union{typeof(max),typeof(min)}, Δ, dst, y)
+#     mask_y = (dst .== op.(dst, y))
+#     mask_y .* Δ
+# end
+
+# Elementwise factor applied to the gathered cotangent `X`; `Y` is `src` where needed.
+modify_src(::typeof(+), X) = X
+modify_src(::typeof(-), X) = -X
+modify_src(::typeof(*), X, Y) = X
+modify_src(::typeof(/), X, Y) = -X ./ Y.^2
+
+# Initial cotangent of `src`: `Δ` gathered at `idx`; the three-argument method also
+# multiplies by `dst` gathered at `idx`.
+∇src_init!(Δ, idx) = gather(Δ, idx)
+∇src_init!(Δ, dst, idx) = gather(dst, idx) .* ∇src_init!(Δ, idx)
+∇src_init(Δ, idx) = gather(Δ, idx)
+
+∇scatter_src!(op::Union{typeof(+),typeof(-)}, Δ, dst, src, idx) =
+    modify_src(op, ∇src_init!(Δ, idx))
+∇scatter_src(op::Union{typeof(+),typeof(-)}, Δ, dst, src, idx) =
+    modify_src(op, ∇src_init(Δ, idx))
+
+# Combine each entry of `Δsrc`, via `op`, with the product of the other `src` entries
+# that `idx` maps to the same key. Mutates and returns `Δsrc`.
+function _fold_shared_src!(op, Δsrc, src::AbstractArray{Tsrc,Nsrc},
+                           idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
+    dims = Nsrc - Nidx
+    rev_idx = reverse_indices(idx)
+    for k in CartesianIndices(idx)
+        inds = filter(x -> x != k, rev_idx[idx[k]])
+        # A key used only once has nothing to fold in; `prod(f, inds)` over an empty
+        # `inds` has no defined result and throws.
+        isempty(inds) && continue
+        for i in CartesianIndices(axes(src)[1:dims])
+            Δsrc[i, k] = op(Δsrc[i, k], prod(j -> src[i, j], inds))
+        end
+    end
+    return Δsrc
+end
+
+# Cotangent of `src` for `scatter!` with `*` or `/`.
+function ∇scatter_src!(op::Union{typeof(*),typeof(/)}, Δ, dst, src::AbstractArray,
+                       idx::AbstractArray)
+    Δsrc = modify_src(op, ∇src_init!(Δ, dst, idx), src)
+    # Fold with `op` rather than always `*`, so `/` divides just as in `∇scatter_src`.
+    return _fold_shared_src!(op, Δsrc, src, idx)
+end
+
+# Cotangent of `src` for `scatter` with `*` or `/`.
+function ∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, src::AbstractArray,
+                      idx::AbstractArray)
+    Δsrc = modify_src(op, ∇src_init(Δ, idx), src)
+    return _fold_shared_src!(op, Δsrc, src, idx)
+end
+
+# TODO: enable once the `max`/`min` gradients pass the tests.
+# ∇scatter_src!(op::Union{typeof(max),typeof(min)}, Δ, dst, src, idx) = (src .== op.(src, gather(dst, idx))) .* ∇src_init!(Δ, idx)
+# ∇scatter_src(op::Union{typeof(max),typeof(min)}, Δ, dst, src, idx) = (src .== op.(src, gather(dst, idx))) .* ∇src_init(Δ, idx)
+
+# Cotangent of `src` for `mean`: `Δ` gathered at `idx`, divided by how often each key
+# occurs in `idx`.
+∇scatter_src!(::typeof(mean), Δ, idx, dims) =
+    divide_by_counts!(∇src_init!(Δ, idx), idx, dims)
+∇scatter_src(::typeof(mean), Δ, idx, dims) =
+    divide_by_counts!(∇src_init(Δ, idx), idx, dims)
+
+function rrule(::typeof(scatter!), op, dst::AbstractArray, src::AbstractArray,
+               idx::AbstractArray)
+    y = scatter!(op, copy(dst), src, idx)
+    function scatter!_pullback(Δ)
+        Δdst = ∇scatter_dst!(op, Δ, dst, y)
+        Δsrc = ∇scatter_src!(op, Δ, dst, src, idx)
+        return (NO_FIELDS, NO_FIELDS, Δdst, Δsrc, DoesNotExist())
+    end
+    return y, scatter!_pullback
+end
+
+function rrule(::typeof(scatter), op, src::AbstractArray, idx::AbstractArray)
+    y = scatter(op, src, idx)
+    function scatter_pullback(Δ)
+        Δsrc = ∇scatter_src(op, Δ, y, src, idx)
+        return (NO_FIELDS, NO_FIELDS, Δsrc, DoesNotExist())
+    end
+    return y, scatter_pullback
+end
+
+function rrule(::typeof(scatter!), op::typeof(mean), dst::AbstractArray,
+               src::AbstractArray{Tsrc,Nsrc},
+               idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
+    dims = Nsrc - Nidx
+    y = scatter!(op, copy(dst), src, idx)
+    function scatter!_pullback(Δ)
+        Δdst = ∇scatter_dst!(op, Δ, dst, y)
+        Δsrc = ∇scatter_src!(op, Δ, idx, dims)
+        return (NO_FIELDS, NO_FIELDS, Δdst, Δsrc, DoesNotExist())
+    end
+    return y, scatter!_pullback
+end
+
+function rrule(::typeof(scatter), op::typeof(mean), src::AbstractArray{Tsrc,Nsrc},
+               idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
+    dims = Nsrc - Nidx
+    y = scatter(op, src, idx)
+    function scatter_pullback(Δ)
+        Δsrc = ∇scatter_src(op, Δ, idx, dims)
+        return (NO_FIELDS, NO_FIELDS, Δsrc, DoesNotExist())
+    end
+    return y, scatter_pullback
+end
diff --git a/src/utils.jl b/src/utils.jl
index dc8e1a16e..a461f72df 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -14,3 +14,49 @@ The maximum of each dimension in the element is computed.
 maximum_dims(dims::AbstractArray{<:Integer}) = (maximum(dims), )
 maximum_dims(dims::AbstractArray{NTuple{N, T}}) where {N,T} = ntuple(i -> maximum(x->x[i], dims), N)
 maximum_dims(dims::AbstractArray{CartesianIndex{N}}) where {N} = ntuple(i -> maximum(x->x[i], dims), N)
+
+"""
+    reverse_indices(idx)
+
+Return a `Dict` mapping every distinct value of `idx` to the vector of
+`CartesianIndex` positions at which it occurs, in iteration order.
+"""
+function reverse_indices(idx::AbstractArray{T,N}) where {T,N}
+    rev = Dict{T,Vector{CartesianIndex{N}}}()
+    # Iterate `CartesianIndices` (not `pairs`) so positions are `CartesianIndex` even
+    # when `idx` is a vector.
+    for ind in CartesianIndices(idx)
+        push!(get!(() -> CartesianIndex{N}[], rev, idx[ind]), ind)
+    end
+    return rev
+end
+
+"""
+    count_indices(idx)
+
+Return an array shaped like `idx` whose entries give how many times the
+corresponding value of `idx` occurs in `idx`.
+"""
+function count_indices(idx::AbstractArray)
+    tally = Dict{eltype(idx),Int}()
+    for val in idx
+        tally[val] = get(tally, val, 0) + 1
+    end
+    return map(val -> tally[val], idx)
+end
+
+"""
+    divide_by_counts!(xs, idx, dims)
+
+Divide, in place, the slice `xs[:, ..., i]` (with `dims` leading colons) by the
+number of occurrences of `idx[i]` in `idx`, for every position `i` of `idx`.
+Return `xs`.
+"""
+function divide_by_counts!(xs, idx::AbstractArray, dims)
+    colons = ntuple(_ -> Colon(), dims)
+    counts = count_indices(idx)
+    for i in CartesianIndices(counts)
+        view(xs, colons..., i) ./= counts[i]
+    end
+    return xs
+end
diff --git a/test/scatter.jl b/test/scatter.jl
index 04789f732..b1186eb43 100644
--- a/test/scatter.jl
+++ b/test/scatter.jl
@@ -179,3 +179,34 @@ types = [UInt8, UInt16, UInt32, UInt64, UInt128,
     idx = [1 2 3 4; 4 2 1 3; 6 7 8 9]
     @test_throws BoundsError scatter!(+, dsts[1], srcs[(1, true)], idx)
 end
+
+@testset "∇scatter" begin
+    T = Float64
+    # NOTE(review): confirm the `srcs` fixtures contain values other than one and that
+    # some index occurs only once; otherwise errors in the `*`/`/` gradients could go
+    # unnoticed.
+    @testset "∂dst" begin
+        for op in (+, -, *, /)
+            # TODO: get max, min pass tests
+            gradtest(xs -> scatter!(op, copy(xs), srcs[(0, true)], idxs[:int]), T.(dsts[0]))
+            gradtest(xs -> scatter!(op, copy(xs), srcs[(1, true)], idxs[:int]), T.(dsts[1]))
+        end
+        gradtest(xs -> scatter!(mean, copy(xs), srcs[(0, true)], idxs[:int]), T.(dsts[0]))
+        gradtest(xs -> scatter!(mean, copy(xs), srcs[(1, true)], idxs[:int]), T.(dsts[1]))
+    end
+
+    @testset "∂src" begin
+        for op in (+, -, *, /)
+            # TODO: get max, min pass tests
+            gradtest(xs -> scatter!(op, T.(dsts[0]), xs, idxs[:int]), T.(srcs[(0, true)]))
+            gradtest(xs -> scatter!(op, T.(dsts[1]), xs, idxs[:int]), T.(srcs[(1, true)]))
+
+            gradtest(xs -> scatter(op, xs, idxs[:int]), T.(srcs[(0, false)]))
+            gradtest(xs -> scatter(op, xs, idxs[:int]), T.(srcs[(1, false)]))
+        end
+        gradtest(xs -> scatter!(mean, T.(dsts[0]), xs, idxs[:int]), T.(srcs[(0, true)]))
+        gradtest(xs -> scatter!(mean, T.(dsts[1]), xs, idxs[:int]), T.(srcs[(1, true)]))
+
+        gradtest(xs -> scatter(mean, xs, idxs[:int]), T.(srcs[(0, false)]))
+        gradtest(xs -> scatter(mean, xs, idxs[:int]), T.(srcs[(1, false)]))
+    end
+end