Skip to content

RFC: broadcasting #68

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"

[compat]
ChainRules = "1.5"
ChainRulesCore = "1.2"
ChainRules = "1.17"
ChainRulesCore = "1.11"
Combinatorics = "1"
StaticArrays = "1"
StatsBase = "0.33"
Expand Down
42 changes: 16 additions & 26 deletions src/extra_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ struct NonDiffOdd{N, O, P}; end
# This should not happen
(::NonDiffEven{N, O, O})(Δ...) where {N, O} = error()

# WARNING: Method definition rrule(typeof(Core.apply_type), Any, Any...) in module ChainRules at /Users/me/.julia/packages/ChainRules/kkDLd/src/rulesets/Core/core.jl:10 overwritten in module Diffractor at /Users/me/.julia/dev/Diffractor/src/extra_rules.jl:140.
@Base.pure function ChainRulesCore.rrule(::typeof(Core.apply_type), head, args...)
Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, 1}()
end
Expand All @@ -145,17 +146,8 @@ function ChainRulesCore.rrule(::typeof(Core.tuple), args...)
Core.tuple(args...), Δ->Core.tuple(NoTangent(), Δ...)
end

# TODO: What to do about these integer rules
@ChainRulesCore.non_differentiable Base.rem(a::Integer, b::Type)

ChainRulesCore.canonicalize(::ChainRulesCore.ZeroTangent) = ChainRulesCore.ZeroTangent()

# Skip AD'ing through the axis computation
function ChainRules.rrule(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
return Base.Broadcast.instantiate(bc), Δ->begin
Core.tuple(NoTangent(), Δ)
end
end
ChainRulesCore.canonicalize(::NoTangent) = NoTangent()


using StaticArrays
Expand Down Expand Up @@ -199,20 +191,6 @@ function ChainRules.rrule(::typeof(map), ::typeof(+), A::AbstractVector, B::Abst
map(+, A, B), Δ->(NoTangent(), NoTangent(), Δ, Δ)
end

function ChainRules.rrule(AT::Type{<:Array{T,N}}, x::AbstractArray{S,N}) where {T,S,N}
# We're leaving these in the eltype that the cotangent vector already has.
# There isn't really a good reason to believe we should convert to the
# original array type, so don't unless explicitly requested.
AT(x), Δ->(NoTangent(), Δ)
end

function ChainRules.rrule(AT::Type{<:Array}, undef::UndefInitializer, args...)
# We're leaving these in the eltype that the cotangent vector already has.
# There isn't really a good reason to believe we should convert to the
# original array type, so don't unless explicitly requested.
AT(undef, args...), Δ->(NoTangent(), NoTangent(), ntuple(_->NoTangent(), length(args))...)
end

# Turn a tuple of pairs "inside out": collect every first entry into one tuple
# and every second entry into another.
function unzip_tuple(t::Tuple)
    lefts = map(p -> p[1], t)
    rights = map(p -> p[2], t)
    return lefts, rights
end
Expand Down Expand Up @@ -252,10 +230,8 @@ function ChainRules.frule(_, ::Type{Vector{T}}, undef::UndefInitializer, dims::I
Vector{T}(undef, dims...), zeros(T, dims...)
end

@ChainRules.non_differentiable Base.:(|)(a::Integer, b::Integer)
@ChainRules.non_differentiable Base.throw(err)
@ChainRules.non_differentiable Core.Compiler.return_type(args...)
ChainRulesCore.canonicalize(::NoTangent) = NoTangent()

# Disable thunking at higher order (TODO: These should go into ChainRulesCore)
function ChainRulesCore.rrule(::Type{Thunk}, thnk)
Expand All @@ -266,3 +242,17 @@ end
# Disable thunking at higher order: constructing an InplaceableThunk is treated
# as returning its value, and the pullback routes the cotangent straight to `val`.
ChainRulesCore.rrule(::Type{InplaceableThunk}, add!!, val) = val, Δ -> (NoTangent(), NoTangent(), Δ)

# ERROR: ArgumentError: Tangent for the primal Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}} should be backed by a AbstractDict type, not by NamedTuple{(:data,), Tuple{ChainRulesCore.ZeroTangent}}.
ChainRulesCore._backing_error(::Type{<:Base.Pairs{Symbol}}, ::Type{<:NamedTuple}, _) = nothing # solves that!

# Partial derivatives of scalar 3- and 4-argument `*`, so broadcasted n-ary `*`
# takes the efficient `derivatives_given_output` fast path instead of a
# dedicated broadcasting rule.
function ChainRulesCore.derivatives_given_output(Ω, ::typeof(*), x::Number, y::Number, z::Number)
    return ((y * z, x * z, x * y),)
end
function ChainRulesCore.derivatives_given_output(Ω, ::typeof(*), x::Number, y::Number, z::Number, w::Number)
    # Reuse the two pairwise products rather than forming each triple product.
    front = x * y
    back = z * w
    return ((y * back, x * back, front * w, front * z),)
end

# Fixes @test (x->sum(isa_control_flow(Matrix{Float64}, x)))'(Float32[1 2;]) == [1.0 1.0;]
# Projecting an InplaceableThunk: materialize it with `unthunk`, then project the value.
(project::ProjectTo{<:AbstractArray})(th::InplaceableThunk) = project(unthunk(th))
197 changes: 170 additions & 27 deletions src/stage1/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,45 +29,188 @@ function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)},
return r
end

# Reverse mode broadcast rules

using ChainRulesCore: derivatives_given_output

# Broadcast over one element is just map
function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
∂⃖ₙ(map, f, a)
# function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
# ∂⃖ₙ(map, f, a)
# end

# `copy` of a lazy Broadcasted just materializes it; the pullback passes Δ through
# unchanged (one cotangent each for `copy` and `bc`).
(::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted) = copy(bc), Δ -> (NoTangent(), Δ)

(::∂⃖{1})(::typeof(broadcasted), f::F, args...) where {F} = split_bc_rule(f, args...)
# (::∂⃖{1})(::typeof(broadcasted), f::F, arg::Array) where {F} = split_bc_rule(f, arg) # ambiguity
# Un-fused reverse rule for `broadcasted(f, args...)`.
# Picks one of three strategies based on what type inference can prove about `f`:
#   1. Bool output           -> non-differentiable, all-zero pullback.
#   2. Number output with an inferrable `derivatives_given_output`
#                             -> broadcast once, recompute derivatives in the pullback.
#   3. Anything else         -> broadcast `∂⃖{1}()(f, ...)`, storing every pullback.
# Returns `(primal, pullback)`; the pullback yields cotangents for
# (broadcasted, f, args...).
function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N}
    T = Broadcast.combine_eltypes(f, args)
    # Inferred result type of the scalar derivative rule; concreteness gates the fast path.
    TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(eltype, args)...})
    if T === Bool
        # Trivial case: non-differentiable output, e.g. `x .> 0`
        # (+2 accounts for the `broadcasted` and `f` slots in the cotangent tuple).
        back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2)
        return f.(args...), back_1
    elseif T <: Number && isconcretetype(TΔ)
        # Fast path: just broadcast, and use arguments & result to find derivatives.
        ys = f.(args...)
        function back_2_one(dys) # For f.(x) we do not need StructArrays / unzip at all
            delta = broadcast(unthunk(dys), ys, args...) do dy, y, a
                das = only(derivatives_given_output(y, f, a))
                dy * conj(only(das)) # possibly this * should be made nan-safe.
            end
            (NoTangent(), NoTangent(), unbroadcast(only(args), delta))
        end
        function back_2_many(dys)
            # One broadcast produces a tuple of derivative arrays, one per argument.
            deltas = tuplecast(unthunk(dys), ys, args...) do dy, y, as...
                das = only(derivatives_given_output(y, f, as...))
                map(da -> dy * conj(da), das)
            end
            dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of splitcast?
            (NoTangent(), NoTangent(), dargs...)
        end
        return ys, N==1 ? back_2_one : back_2_many
    else
        # Slow path: collect all the pullbacks & apply them later.
        # (Since broadcast makes no guarantee about order of calls, and un-fusing
        # can change the number of calls, this does not bother to try to reverse.)
        ys3, backs = tuplecast(∂⃖{1}(), f, args...)
        function back_3(dys)
            deltas = tuplecast(backs, unthunk(dys)) do back, dy # could be map, sizes match
                map(unthunk, back(dy))
            end
            # First element of each pullback result is the cotangent for `f` itself.
            dargs = map(unbroadcast, args, Base.tail(deltas))
            (NoTangent(), sum(first(deltas)), dargs...)
        end
        back_3(::AbstractZero) = (NoTangent(), map(Returns(ZeroTangent()), args)...)
        return ys3, back_3
    end
end

# All-scalar broadcast is just a plain function call (idea borrowed from Zygote):
# differentiate `f` directly and prepend a NoTangent for `broadcasted` itself.
function split_bc_rule(f::F, args::Number...) where {F}
    y, back = ∂⃖{1}()(f, args...)
    function scalar_back(dy)
        return (NoTangent(), back(dy)...)
    end
    return y, scalar_back
end

accum_sum(xs::Nothing; dims = :) = NoTangent()
accum_sum(xs::AbstractArray{Nothing}; dims = :) = NoTangent()
accum_sum(xs::AbstractArray{<:Number}; dims = :) = sum(xs, dims = dims)
accum_sum(xs::AbstractArray{<:AbstractArray{<:Number}}; dims = :) = sum(xs, dims = dims)
accum_sum(xs::Number; dims = :) = xs
# Shortcut: broadcasted `identity` needs no derivative machinery; Δ passes through.
# The `::Number` method exists to disambiguate against the all-scalar
# `split_bc_rule(f, args::Number...)` method above — TODO confirm.
split_bc_rule(::typeof(identity), x) = x, Δ -> (NoTangent(), NoTangent(), Δ)
split_bc_rule(::typeof(identity), x::Number) = x, Δ -> (NoTangent(), NoTangent(), Δ)

# https://github.com/FluxML/Zygote.jl/issues/594
function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArray, region)
Base.reducedim_initarray(A, region, NoTangent(), Union{Nothing,eltype(A)})
# Treat `Broadcast.instantiate` as pass-through for AD: skip differentiating the
# axis computation and hand the cotangent straight back.
function (::∂⃖{1})(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
    return Base.Broadcast.instantiate(bc), Δ -> (NoTangent(), Δ)
end

trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
using StructArrays

# Broadcast a tuple-returning `f` over `args`, then split the resulting
# array-of-tuples into a tuple of arrays via StructArrays.
function tuplecast(f::F, args...) where {F}
    T = Broadcast.combine_eltypes(f, args)
    if isconcretetype(T) && !(T <: Tuple)
        throw(ArgumentError("tuplecast(f, args) only works on functions returning a tuple."))
    end
    bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
    return StructArrays.components(StructArray(bc))
end

unbroadcast(x::AbstractArray, x̄) =
size(x) == size(x̄) ? x̄ :
length(x) == length(x̄) ? trim(x, x̄) :
trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))
# For certain cheap operations we can easily allow fused broadcast:
const NumericOrBroadcast = Union{Number, AbstractArray{<:Number}, NTuple{<:Any,Number}, Broadcast.Broadcasted}

unbroadcast(x::Number, x̄) = accum_sum(x̄)
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
# Fused/lazy rule for broadcasted `+`: keep the primal lazy (a Broadcasted) and,
# since `+` is linear, hand each argument the cotangent Δ reduced to its shape.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args::NumericOrBroadcast...) = lazy_bc_plus(args...)
# All-scalar case falls through to the plain scalar rule. Vararg `Number...` so
# any number of scalar arguments is caught (a bare `args::Number` matched exactly
# one scalar, silently sending multi-scalar calls to the lazy path).
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args::Number...) = split_bc_rule(+, args...)
# (dropped a vestigial, unused static parameter `where {F}` from the original)
function lazy_bc_plus(xs...)
    broadcasted(+, xs...), Δraw -> let Δ = unthunk(Δraw)
        (NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δ), xs)...)
    end
end

unbroadcast(x::AbstractArray, x̄::Nothing) = NoTangent()
# Broadcasted subtraction stays lazy; the cotangents are ±Δ, reduced to each shape.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x::Number, y::Number) = split_bc_rule(-, x, y)
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast)
    function minus_back(Δraw)
        Δ = unthunk(Δraw)
        return (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ))
    end
    return broadcasted(-, x, y), minus_back
end

const Numeric = Union{Number, AbstractArray{<:Number, N} where N}
using LinearAlgebra: dot

function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(+), xs::Numeric...)
broadcast(+, xs...), ȳ -> (NoTangent(), NoTangent(), map(x -> unbroadcast(x, unthunk(ȳ)), xs)...)
# Broadcasted multiplication stays lazy; each side's cotangent is Δ times the
# conjugate of the other side, reduced to that side's shape.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::Number, y::Number) = split_bc_rule(*, x, y)
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast)
    function times_back(Δraw)
        Δ = unthunk(Δraw)
        return (NoTangent(), NoTangent(), _back_star(x, y, Δ), _back_star(y, x, Δ))
    end
    return broadcasted(*, x, y), times_back
end
_back_star(x, y, Δ) = unbroadcast(x, Δ .* conj.(y))   # array-like: elementwise, then shape-reduce
_back_star(x::Number, y, Δ) = dot(y, Δ)               # scalar x: dot(y, Δ) == sum(conj(y) .* Δ)
_back_star(x::Bool, y, Δ) = NoTangent()               # Bool is non-differentiable

ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y,
Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end
# `x .^ 2` arrives as broadcasted(literal_pow, ^, x, Val(2)): forward it as the
# lazy `x .* x`, with d/dx = 2 * Δ * conj(x). The cotangent tuple has five slots,
# one per primal slot: (broadcasted, literal_pow, ^, x, Val(2)).
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::NumericOrBroadcast, ::Val{2})
    broadcasted(*, x, x), Δ -> begin
        dx = unbroadcast(x, 2 .* unthunk(Δ) .* conj.(x))
        (NoTangent(), NoTangent(), NoTangent(), dx, NoTangent())
    end
end
# Scalar counterpart: computed eagerly, same five-slot cotangent layout.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{2})
    x^2, Δ -> (NoTangent(), NoTangent(), NoTangent(), 2 * Δ * conj(x), NoTangent())
end

(::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::Number, y::Number) = split_bc_rule(/, x, y)
# Broadcasted division by a scalar. Computed eagerly (not lazily) on purpose:
# the pullback for `y` reuses the primal `z` and collapses to a single `dot`.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number)
    z = broadcast(/, x, y)
    z, Δth -> let Δ = unthunk(Δth)
        dx = unbroadcast(x, Δ ./ conj.(y))
        dy = -dot(z, Δ) / (conj(y)) # the reason to be eager is to allow dot here
        (NoTangent(), NoTangent(), dx, dy)
    end
end

# Broadcasted `identity` goes straight to the pass-through rule.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x) = split_bc_rule(identity, x)
# (::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x::Array) = split_bc_rule(identity, x) # ambiguity

# `conj` on real-eltype arrays is identity. NOTE: this was `AbstractArray{Real}`,
# which only matches arrays whose eltype is exactly the abstract type `Real`;
# `<:Real` makes it match Vector{Float64} etc., which is clearly the intent.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::AbstractArray{<:Real}) = split_bc_rule(identity, x)
# (::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array{<:Real}) = split_bc_rule(identity, x) # ambiguity
# Generic `conj`: pullback is conj(Δ). Cotangents cover (broadcasted, conj, x),
# so three entries are needed — the original dropped the NoTangent() for `conj`,
# unlike every sibling rule in this file.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x) =
    broadcasted(conj, x), Δ -> (NoTangent(), NoTangent(), conj(unthunk(Δ)))
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array) =
    broadcasted(conj, x), Δ -> (NoTangent(), NoTangent(), conj(unthunk(Δ)))

# Reverse mode broadcasting uses `unbroadcast` to reduce to correct shape:
# `dx` has the broadcast-result shape; sum over dimensions along which `x` was
# virtually repeated, then `ProjectTo` restores x's exact type/shape.
function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx)
    N = ndims(dx)
    if length(x) == length(dx)
        ProjectTo(x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors
    else
        # Dimensions where size(x) is 1 (or missing) were broadcast; sum those.
        dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N+1, N) # hack to get type-stable `dims`
        ProjectTo(x)(sum(dx; dims)) # ideally this sum might be thunked?
    end
end
# A zero cotangent needs no shape reduction.
unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx

# A 1-tuple broadcasts like a scalar: total the cotangent and wrap as a Tangent.
unbroadcast(x::T, dx) where {T<:Tuple{Any}} = ProjectTo(x)(Tangent{T}(sum(dx)))
# General tuples broadcast along dimension 1; sum any extra dimensions of dx.
function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N}
    val = if length(x) == length(dx)
        dx
    else
        sum(dx; dims=2:ndims(dx))
    end
    ProjectTo(x)(NTuple{length(x)}(val)) # Tangent
end

ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y,
z̄ -> let z̄=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end
# Scalar-like broadcast participants: the whole cotangent collapses to a sum.
unbroadcast(f::Function, df) = sum(df)
unbroadcast(x::Number, dx) = ProjectTo(x)(sum(dx))
unbroadcast(x::Base.RefValue, dx) = ProjectTo(x)(Ref(sum(dx)))

# Non-differentiable broadcast participants:
unbroadcast(::Bool, dx) = NoTangent()
unbroadcast(::AbstractArray{Bool}, dx) = NoTangent()
unbroadcast(::AbstractArray{Bool}, ::NoTangent) = NoTangent() # ambiguity
unbroadcast(::Val, dx) = NoTangent()

# Fallback for argument types without a dedicated `unbroadcast` method:
# zeros stay zero; anything that broadcasts as a scalar is summed and projected;
# everything else is an error rather than a silently wrong gradient.
function unbroadcast(x, dx)
    project = ProjectTo(x)
    if dx isa AbstractZero || project isa ProjectTo{<:AbstractZero}
        return NoTangent()
    end
    if Broadcast.broadcastable(x) isa Ref # then x is scalar under broadcast
        return project(sum(dx))
    end
    error("don't know how to handle broadcast gradient for x::$(typeof(x))")
end
2 changes: 2 additions & 0 deletions src/stage1/generated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ end
# An OpticBundle indexes and iterates like the 2-element pair (o.x, o.clos).
Base.getindex(o::OpticBundle, i::Int) = i == 1 ? o.x :
    i == 2 ? o.clos :
    throw(BoundsError(o, i))
# Enables `o[end]` syntax (the pair always has exactly two elements).
Base.lastindex(o::OpticBundle) = 2

# Hand-rolled two-step iteration: state goes nothing -> missing -> done.
Base.iterate(o::OpticBundle) = (o.x, nothing)
Base.iterate(o::OpticBundle, ::Nothing) = (o.clos, missing)
Base.iterate(o::OpticBundle, ::Missing) = nothing
Expand Down
Loading