Skip to content

Commit

Permalink
Merge pull request #297 from yuehhua/scatter-ad
Browse files Browse the repository at this point in the history
Add rrule for scatter
  • Loading branch information
CarloLucibello authored May 14, 2021
2 parents 908ec40 + 4ccd332 commit ed98dd8
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 0 deletions.
90 changes: 90 additions & 0 deletions src/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,93 @@ function scatter(op::typeof(mean),
fill!(dst, Base.reduce_empty(+, FT))
scatter!(op, dst, src, idx)
end

## Gradients

opname(::typeof(+)) = :add
opname(::typeof(-)) = :sub
opname(::typeof(*)) = :mul
opname(::typeof(/)) = :div


∇scatter_dst!(op, Δ, dst, y) = Δ

# function ∇scatter_dst!(op::Union{typeof(max),typeof(min)}, Δ, dst, y)
# mask_y = (dst .== op.(dst, y))
# mask_y .* Δ
# end

modify_src(::typeof(+), X) = X
modify_src(::typeof(-), X) = -X
modify_src(::typeof(*), X, Y) = X
modify_src(::typeof(/), X, Y) = -X ./ Y.^2

∇src_init!(Δ, idx) = gather(Δ, idx)
∇src_init!(Δ, dst, idx) = gather(dst, idx) .* ∇src_init!(Δ, idx)
∇src_init(Δ, idx) = gather(Δ, idx)

∇scatter_src!(op::Union{typeof(+),typeof(-)}, Δ, dst, src, idx) = modify_src(op, ∇src_init!(Δ, idx))
∇scatter_src(op::Union{typeof(+),typeof(-)}, Δ, dst, src, idx) = modify_src(op, ∇src_init(Δ, idx))

function ∇scatter_src!(op::Union{typeof(*),typeof(/)}, Δ, dst,
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
Δsrc = modify_src(op, ∇src_init!(Δ, dst, idx), src)
rev_idx = reverse_indices(idx)
for k = CartesianIndices(idx)
inds = filter(x -> x != k, rev_idx[idx[k]])
for i = CartesianIndices(axes(src)[1:dims])
Δsrc[i, k] *= prod(j -> src[i, j], inds)
end
end
Δsrc
end

function ∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst,
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
Δsrc = modify_src(op, ∇src_init(Δ, idx), src)
rev_idx = reverse_indices(idx)
for k = CartesianIndices(idx)
inds = filter(x -> x != k, rev_idx[idx[k]])
for i = CartesianIndices(axes(src)[1:dims])
Δsrc[i, k] = op(Δsrc[i, k], prod(j -> src[i, j], inds))
end
end
Δsrc
end

# ∇scatter_src!(op::Union{typeof(max),typeof(min)}, Δ, dst, src, idx) = (src .== op.(src, gather(dst, idx))) .* ∇src_init!(Δ, idx)
# ∇scatter_src(op::Union{typeof(max),typeof(min)}, Δ, dst, src, idx) = (src .== op.(src, gather(dst, idx))) .* ∇src_init(Δ, idx)

∇scatter_src!(::typeof(mean), Δ, idx, dims) = divide_by_counts!(∇src_init!(Δ, idx), idx, dims)
∇scatter_src(::typeof(mean), Δ, idx, dims) = divide_by_counts!(∇src_init(Δ, idx), idx, dims)


function rrule(::typeof(scatter!), op, dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
y = scatter!(op, copy(dst), src, idx)
scatter!_pullback(Δ) = (NO_FIELDS, NO_FIELDS, ∇scatter_dst!(op, Δ, dst, y), ∇scatter_src!(op, Δ, dst, src, idx), DoesNotExist())
y, scatter!_pullback
end

function rrule(::typeof(scatter), op, src::AbstractArray, idx::AbstractArray)
y = scatter(op, src, idx)
scatter_pullback(Δ) = (NO_FIELDS, NO_FIELDS, ∇scatter_src(op, Δ, y, src, idx), DoesNotExist())
y, scatter_pullback
end

function rrule(::typeof(scatter!), op::typeof(mean), dst::AbstractArray, src::AbstractArray{Tsrc,Nsrc}, idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
y = scatter!(op, copy(dst), src, idx)
scatter!_pullback(Δ) = (NO_FIELDS, NO_FIELDS, ∇scatter_dst!(op, Δ, dst, y), ∇scatter_src!(op, Δ, idx, dims), DoesNotExist())
y, scatter!_pullback
end

function rrule(::typeof(scatter), op::typeof(mean), src::AbstractArray{Tsrc,Nsrc}, idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
y = scatter(op, src, idx)
scatter_pullback(Δ) = (NO_FIELDS, NO_FIELDS, ∇scatter_src(op, Δ, idx, dims), DoesNotExist())
y, scatter_pullback
end
26 changes: 26 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,29 @@ The maximum of each dimension in the element is computed.
maximum_dims(dims::AbstractArray{<:Integer}) = (maximum(dims), )
maximum_dims(dims::AbstractArray{NTuple{N, T}}) where {N,T} = ntuple(i -> maximum(x->x[i], dims), N)
maximum_dims(dims::AbstractArray{CartesianIndex{N}}) where {N} = ntuple(i -> maximum(x->x[i], dims), N)

function reverse_indices(idx::Array{T}) where T
rev = Dict{T,Vector{CartesianIndex}}()
for (ind, val) = pairs(idx)
rev[val] = get(rev, val, CartesianIndex[])
push!(rev[val], ind)
end
rev
end

function count_indices(idx::AbstractArray)
counts = zero.(idx)
for i = unique(idx)
counts += sum(idx.==i) * (idx.==i)
end
return counts
end

function divide_by_counts!(xs, idx::AbstractArray, dims)
colons = Base.ntuple(_->Colon(), dims)
counts = count_indices(idx)
for i = CartesianIndices(counts)
view(xs, colons..., i) ./= counts[i]
end
return xs
end
29 changes: 29 additions & 0 deletions test/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,32 @@ types = [UInt8, UInt16, UInt32, UInt64, UInt128,
idx = [1 2 3 4; 4 2 1 3; 6 7 8 9]
@test_throws BoundsError scatter!(+, dsts[1], srcs[(1, true)], idx)
end

@testset "∇scatter" begin
T = Float64
@testset "∂dst" begin
for op in (+, -, *, /)
# TODO: get max, min pass tests
gradtest(xs -> scatter!(op, copy(xs), srcs[(0, true)], idxs[:int]), T.(dsts[0]))
gradtest(xs -> scatter!(op, copy(xs), srcs[(1, true)], idxs[:int]), T.(dsts[1]))
end
gradtest(xs -> scatter!(mean, copy(xs), srcs[(0, true)], idxs[:int]), T.(dsts[0]))
gradtest(xs -> scatter!(mean, copy(xs), srcs[(1, true)], idxs[:int]), T.(dsts[1]))
end

@testset "∂src" begin
for op in (+, -, *, /)
# TODO: get max, min pass tests
gradtest(xs -> scatter!(op, T.(dsts[0]), xs, idxs[:int]), T.(srcs[(0, true)]))
gradtest(xs -> scatter!(op, T.(dsts[1]), xs, idxs[:int]), T.(srcs[(1, true)]))

gradtest(xs -> scatter(op, xs, idxs[:int]), T.(srcs[(0, false)]))
gradtest(xs -> scatter(op, xs, idxs[:int]), T.(srcs[(1, false)]))
end
gradtest(xs -> scatter!(mean, T.(dsts[0]), xs, idxs[:int]), T.(srcs[(0, true)]))
gradtest(xs -> scatter!(mean, T.(dsts[1]), xs, idxs[:int]), T.(srcs[(1, true)]))

gradtest(xs -> scatter(mean, xs, idxs[:int]), T.(srcs[(0, false)]))
gradtest(xs -> scatter(mean, xs, idxs[:int]), T.(srcs[(1, false)]))
end
end

0 comments on commit ed98dd8

Please sign in to comment.