Skip to content

Commit

Permalink
scatter rrules
Browse files Browse the repository at this point in the history
fix tests

fix test case

add test cases for scatter

fix rrule for scatter!

fix ∇scatter_src! for max/min

simplify pullback function name

fix bug

unify interfaces
  • Loading branch information
yuehhua committed Apr 8, 2021
1 parent 6905ca7 commit a4e9f88
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 0 deletions.
90 changes: 90 additions & 0 deletions src/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,93 @@ function scatter(op::typeof(mean),
fill!(dst, Base.reduce_empty(+, FT))
scatter!(op, dst, src, idx)
end

## Gradients

opname(::typeof(+)) = :add
opname(::typeof(-)) = :sub
opname(::typeof(*)) = :mul
opname(::typeof(/)) = :div


∇scatter_dst!(op, Δ, dst, y) = Δ

function ∇scatter_dst!(op::Union{typeof(max),typeof(min)}, Δ, dst, y)
mask_y = (dst .== op.(dst, y))
mask_y .* Δ
end

modify_src(::typeof(+), X) = X
modify_src(::typeof(-), X) = -X
modify_src(::typeof(*), X, Y) = X
modify_src(::typeof(/), X, Y) = -X ./ Y.^2

∇src_init!(Δ, idx) = gather(Δ, idx)
∇src_init!(Δ, dst, idx) = gather(dst, idx) .* ∇src_init!(Δ, idx)
∇src_init(Δ, idx) = gather(Δ, idx)

∇scatter_src!(op::Union{typeof(+),typeof(-)}, Δ, dst, src, idx) = modify_src(op, ∇src_init!(Δ, idx))
∇scatter_src(op::Union{typeof(+),typeof(-)}, Δ, dst, src, idx) = modify_src(op, ∇src_init(Δ, idx))

function ∇scatter_src!(op::Union{typeof(*),typeof(/)}, Δ, dst,
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
Δsrc = modify_src(op, ∇src_init!(Δ, dst, idx), src)
rev_idx = reverse_indices(idx)
for k = CartesianIndices(idx)
inds = filter(x -> x != k, rev_idx[idx[k]])
for i = CartesianIndices(axes(src)[1:dims])
Δsrc[i, k] *= prod(j -> src[i, j], inds)
end
end
Δsrc
end

function ∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst,
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
Δsrc = modify_src(op, ∇src_init(Δ, idx), src)
rev_idx = reverse_indices(idx)
for k = CartesianIndices(idx)
inds = filter(x -> x != k, rev_idx[idx[k]])
for i = CartesianIndices(axes(src)[1:dims])
Δsrc[i, k] = op(Δsrc[i, k], prod(j -> src[i, j], inds))
end
end
Δsrc
end

∇scatter_src!(op::Union{typeof(max),typeof(min)}, Δ, dst, src, idx) = (src .== op.(src, gather(dst, idx))) .* ∇src_init!(Δ, idx)
∇scatter_src(op::Union{typeof(max),typeof(min)}, Δ, dst, src, idx) = (src .== op.(src, gather(dst, idx))) .* ∇src_init(Δ, idx)

∇scatter_src!(::typeof(mean), Δ, idx, dims) = divide_by_counts!(∇src_init!(Δ, idx), idx, dims)
∇scatter_src(::typeof(mean), Δ, idx, dims) = divide_by_counts!(∇src_init(Δ, idx), idx, dims)


function rrule(::typeof(scatter!), op, dst::AbstractArray, src::AbstractArray, idx::AbstractArray, dims)
y = scatter!(op, copy(dst), src, idx, dims)
scatter!_pullback(Δ) = (NO_FIELDS, NO_FIELDS, ∇scatter_dst!(op, Δ, dst, y), ∇scatter_src!(op, Δ, dst, src, idx), DoesNotExist(), DoesNotExist())
y, scatter!_pullback
end

function rrule(::typeof(scatter), op, src::AbstractArray, idx::AbstractArray)
y = scatter(op, src, idx)
scatter_pullback(Δ) = (NO_FIELDS, NO_FIELDS, ∇scatter_src(op, Δ, y, src, idx), DoesNotExist())
y, scatter_pullback
end

function rrule(::typeof(scatter!), op::typeof(mean), dst::AbstractArray, src::AbstractArray{Tsrc,Nsrc}, idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
y = scatter!(op, copy(dst), src, idx)
scatter!_pullback(Δ) = (NO_FIELDS, NO_FIELDS, ∇scatter_dst!(op, Δ, dst, y), ∇scatter_src!(op, Δ, idx, dims), DoesNotExist())
y, scatter!_pullback
end

function rrule(::typeof(scatter), op::typeof(mean), src::AbstractArray{Tsrc,Nsrc}, idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
y = scatter(op, src, idx)
scatter_pullback(Δ) = (NO_FIELDS, NO_FIELDS, ∇scatter_src(op, Δ, idx, dims), DoesNotExist())
y, scatter_pullback
end
26 changes: 26 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,29 @@ maximum_dims(dims::AbstractArray{<:Integer}) = (maximum(dims), )
function maximum_dims(dims::AbstractArray{<:Tuple})
Tuple(maximum(xs) for xs in zip(dims...))
end

function reverse_indices(idx::Array{T}) where T
rev = Dict{T,Vector{CartesianIndex}}()
for (ind, val) = pairs(idx)
rev[val] = get(rev, val, CartesianIndex[])
push!(rev[val], ind)
end
rev
end

function count_indices(idx::AbstractArray)
counts = zero.(idx)
for i = unique(idx)
counts += sum(idx.==i) * (idx.==i)
end
return counts
end

function divide_by_counts!(xs, idx::AbstractArray, dims)
colons = Base.ntuple(_->Colon(), dims)
counts = count_indices(idx)
for i = CartesianIndices(counts)
view(xs, colons..., i) ./= counts[i]
end
return xs
end
27 changes: 27 additions & 0 deletions test/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,30 @@ types = [UInt8, UInt16, UInt32, UInt64, UInt128,
idx = [1 2 3 4; 4 2 1 3; 6 7 8 9]
@test_throws BoundsError scatter!(+, dsts[1], srcs[(1, true)], idx)
end

@testset "∇scatter" begin
T = Float64
@testset "∂dst" begin
for op in (+, -, max, min, *, /)
gradtest(xs -> scatter!(op, copy(xs), srcs[(0, true)], idxs[:int], Val(0)), T.(dsts[0]))
gradtest(xs -> scatter!(op, copy(xs), srcs[(1, true)], idxs[:int], Val(1)), T.(dsts[1]))
end
gradtest(xs -> scatter!(mean, copy(xs), srcs[(0, true)], idxs[:int]), T.(dsts[0]))
gradtest(xs -> scatter!(mean, copy(xs), srcs[(1, true)], idxs[:int]), T.(dsts[1]))
end

@testset "∂src" begin
for op in (+, -, max, min, *, /)
gradtest(xs -> scatter!(op, T.(dsts[0]), xs, idxs[:int], Val(0)), T.(srcs[(0, true)]))
gradtest(xs -> scatter!(op, T.(dsts[1]), xs, idxs[:int], Val(1)), T.(srcs[(1, true)]))

gradtest(xs -> scatter(op, xs, idxs[:int]), T.(srcs[(0, false)]))
gradtest(xs -> scatter(op, xs, idxs[:int]), T.(srcs[(1, false)]))
end
gradtest(xs -> scatter!(mean, T.(dsts[0]), xs, idxs[:int]), T.(srcs[(0, true)]))
gradtest(xs -> scatter!(mean, T.(dsts[1]), xs, idxs[:int]), T.(srcs[(1, true)]))

gradtest(xs -> scatter(mean, xs, idxs[:int]), T.(srcs[(0, false)]))
gradtest(xs -> scatter(mean, xs, idxs[:int]), T.(srcs[(1, false)]))
end
end

0 comments on commit a4e9f88

Please sign in to comment.