
Faster generic broadcasting #1001

Merged: 18 commits, Jun 25, 2021
1 change: 1 addition & 0 deletions src/lib/array.jl
@@ -194,6 +194,7 @@ end

struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]
(::StaticGetter{i})(::Nothing) where {i} = nothing
@generated function _unzip(tuples, ::Val{N}) where {N}
Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i ∈ 1:N)...)
end
42 changes: 32 additions & 10 deletions src/lib/broadcast.jl
@@ -164,31 +164,53 @@ end
# Avoid hitting special cases for `Adjoint` etc.
_broadcast(f::F, x...) where F = materialize(broadcasted(f, x...))

_get(x::Tuple, i) = x[i]
_get(::Nothing, i) = nothing
collapse_nothings(xs::AbstractArray{Nothing}) = nothing
collapse_nothings(xs) = xs

@adjoint function broadcasted(::AbstractArrayStyle, f, args...)
_purefun(::Type{F}) where {F<:Function} = isempty(fieldnames(F))
_purefun(::Type) = false
if VERSION >= v"1.6"
_purefun(::Type{ComposedFunction{F,G}}) where {F,G} = _purefun(F) && _purefun(G)
end
_purefun(::Type{typeof(^)}) = false # fix @testset "power" & @testset "diagonal hessian"

_dualsafe(x::Numeric{<:Real}) = true
_dualsafe(x::Ref{<:Numeric{<:Real}}) = true
_dualsafe(x::Val) = true
_dualsafe(x::Type) = true
_dualsafe(x::Symbol) = true
_dualsafe(x) = false

@adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
T = Broadcast.combine_eltypes(f, args)
# Avoid generic broadcasting in two easy cases:
if T == Bool
return f.(args...), _->nothing
elseif T <: Real && isconcretetype(T) && _purefun(F) && all(_dualsafe, args)
y, back = broadcast_forward(f, args...)
return y, ȳ -> (nothing, nothing, back(ȳ)...)
end
len = inclen(args)
y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
y = map(x -> x[1], y∂b)
∂b = map(x -> x[2], y∂b)
y, function (ȳ)
dxs_zip = map((∂b, ȳ) -> ∂b(ȳ), ∂b, ȳ)
dxs = collapse_nothings.(ntuple(i -> map(x -> _get(x, i), dxs_zip), len))
y = map(first, y∂b)
function ∇broadcasted(ȳ)
dxs_zip = map(((_, pb), ȳ₁) -> pb(ȳ₁), y∂b, ȳ)
dxs = ntuple(len) do i
collapse_nothings(map(StaticGetter{i}(), dxs_zip))
end
(nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...)
Comment on lines +193 to 196 (Member, Author):

Clearly map(StaticGetter{i}(), dxs_zip) should really be fused with unbroadcast, possibly into mapreduce.
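
A rough sketch of that fusion idea (illustrative only, not part of this PR or of Zygote's code; _accum and fused_sum are made-up names, and it covers only arguments whose gradient is a plain sum, where unbroadcast would otherwise also handle reshaping):

# Standalone sketch: extract slot i from each pullback output and reduce in one
# pass with mapreduce, instead of first materializing map(StaticGetter{i}(), dxs_zip).
struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]
(::StaticGetter{i})(::Nothing) where {i} = nothing

_accum(a, b) = a + b                      # stand-in for Zygote's accum
_accum(a, ::Nothing) = a
_accum(::Nothing, b) = b
_accum(::Nothing, ::Nothing) = nothing

fused_sum(::Val{i}, dxs_zip) where {i} = mapreduce(StaticGetter{i}(), _accum, dxs_zip)

dxs_zip = [(0.1, 0.2), (0.3, 0.4), (nothing, 0.5)]  # as if from three pullback calls
fused_sum(Val(1), dxs_zip)  # 0.1 + 0.3, skipping the nothing, with no intermediate array
fused_sum(Val(2), dxs_zip)  # 0.2 + 0.4 + 0.5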

end
y, ∇broadcasted
end

@adjoint function broadcasted(::AbstractArrayStyle{0}, f, args...)
len = inclen(args)
y, ∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
y, function (ȳ)
function ∇broadcasted0(ȳ)
dxs = ∂b(ȳ)
dxs === nothing && return nothing
(nothing, dxs...)
end
y, ∇broadcasted0
end

# Use the `map` adjoint in this special case, which is the same but applies
40 changes: 30 additions & 10 deletions test/features.jl
@@ -500,14 +500,34 @@ end
@test 150_000_000 > @allocated gradient(loss, ones(1000,1000))
end

@testset "tuples & broadcasting" begin
@test gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),)
@test gradient(x -> sum(x .+ ones(2,2)), (1,)) == ((4,),)
@test gradient(x -> sum(x .+ ones(2,1)), (1,2)) == ((1,1),)

# https://github.com/FluxML/Zygote.jl/issues/975
gt = gradient((x,p) -> prod(x .^ p), [3,4], (1,2))
gv = gradient((x,p) -> prod(x .^ p), [3,4], [1,2])
@test gt[1] == gv[1]
@test collect(gt[2]) ≈ gv[2]
@testset "tricky broadcasting" begin
@test gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),)
@test gradient(x -> sum(x .+ ones(2,2)), (1,)) == ((4,),)
@test gradient(x -> sum(x .+ ones(2,1)), (1,2)) == ((1,1),)

# https://github.com/FluxML/Zygote.jl/issues/975
gt = gradient((x,p) -> prod(x .^ p), [3,4], (1,2))
gv = gradient((x,p) -> prod(x .^ p), [3,4], [1,2])
@test gt[1] == gv[1]
@test collect(gt[2]) ≈ gv[2]

# closure captures y -- can't use ForwardDiff
@test gradient((x,y) -> sum((z->z^2+y[1]).(x)), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0])
@test gradient((x,y) -> sum((z->z^2+y[1]), x), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0])
@test gradient((x,y) -> sum(map((z->z^2+y[1]), x)), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0])
@test gradient((x,y) -> mapreduce((z->z^2+y[1]), +, x), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0])

# type unstable
@test gradient(xs -> sum((x -> x<2 ? false : x^2).(xs)), [1,2,3])[1][2:3] == [4, 6]
@test gradient(xs -> sum((x -> x<2 ? false : x^2), xs), [1,2,3])[1][2:3] == [4, 6]
@test gradient(xs -> sum(map((x -> x<2 ? false : x^2), xs)), [1,2,3])[1][2:3] == [4, 6]
@test gradient(xs -> mapreduce((x -> x<2 ? false : x^2), +, xs), [1,2,3])[1][2:3] == [4, 6]

# with Ref, Val, Symbol
@test gradient(x -> sum(x .+ Ref(x[1])), [1,2,3]) == ([4,1,1],)
@test gradient(x -> sum(x .+ (x[1],)), [1,2,3]) == ([4,1,1],)
@test gradient(x -> sum((first∘tuple).(x, :ignore)), [1,2,3]) == ([1,1,1],)
@test gradient(x -> sum((first∘tuple).(x, Symbol)), [1,2,3]) == ([1,1,1],)
_f(x,::Val{y}) where {y} = x/y
@test gradient(x -> sum(_f.(x, Val(2))), [1,2,3]) == ([0.5, 0.5, 0.5],)
end
3 changes: 2 additions & 1 deletion test/gradcheck.jl
@@ -1295,7 +1295,8 @@ end
end

@testset "broadcast" begin
@test gradient(x -> sum(sin.(x)), Diagonal(randn(3)))[1][2] == 1
# Before https://github.com/FluxML/Zygote.jl/pull/1001 this gave [1 1 1; 1 0 1; 1 1 -1]
@test gradient(x -> sum(sin.(x)), Diagonal([0,pi/2,pi]))[1] ≈ [1 0 0; 0 0 0; 0 0 -1]

a = rand(3)
b = rand(2,2)