Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use _pullback inside rules instead of pullback #1385

Merged
merged 2 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using LinearAlgebra, Statistics
using LinearAlgebra: copytri!, AbstractTriangular

import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
literal_getproperty, literal_getfield
literal_getproperty, literal_getfield, unthunk_tangent

using ChainRulesCore
using ChainRules: ChainRules, rrule, unthunk, canonicalize
Expand Down
44 changes: 19 additions & 25 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,10 @@ end
enumerate(xs), back
end

@adjoint Iterators.Filter(f, x) = pullback(filter, f, collect(x))
function _pullback(cx::AContext, ::Type{<:Iterators.Filter}, f, x)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note on why these are _pullback instead of @adjoint: @adjoint tries to do some clever postprocessing of the pullback return which results in mismatched results (differential arg tuple off by one). Using _pullback is slightly more code and requires a bit more care, but it removes a few layers from the call stack and brings us closer to removing @adjoint completely.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, when Miha and I were trying to get rid of Zygote's own types we likewise removed some uses of @adjoint and wrote the _pullback overloads directly

res, back = _pullback(cx, filter, f, collect(x))
return res, back ∘ unthunk_tangent
end

_ndims(::Base.HasShape{d}) where {d} = d
_ndims(x) = Base.IteratorSize(x) isa Base.HasShape ? _ndims(Base.IteratorSize(x)) : 1
Expand Down Expand Up @@ -321,18 +324,12 @@ end
end
end

@adjoint function sum(f, xs::AbstractArray{<:AbstractArray}; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback((f, xs) -> sum(f.(xs); kws...), __context__, f, xs)
end

@adjoint function sum(xs::AbstractArray{Bool}; dims = :)
sum(xs, dims = dims), Δ -> (nothing,)
end

function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
y, back = pullback((f, xs) -> prod(f.(xs)), cx, f, xs)
y, ȳ -> (nothing, back(ȳ)...)
return _pullback(cx, (f, xs) -> prod(f.(xs)), f, xs)
end

@adjoint real(x::AbstractArray) = real(x), r̄ -> (real(r̄),)
Expand All @@ -357,8 +354,14 @@ function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
end
_kron(a::AbstractVector, b::AbstractVector) = vec(_kron(reshape(a, :, 1), reshape(b, :, 1)))

@adjoint kron(a::AbstractMatrix, b::AbstractMatrix) = pullback(_kron, a, b)
@adjoint kron(a::AbstractVector, b::AbstractVector) = pullback(_kron, a, b)
function _pullback(cx::AContext, ::typeof(kron), a::AbstractVector, b::AbstractVector)
res, back = _pullback(cx, _kron, a, b)
return res, back ∘ unthunk_tangent
end
function _pullback(cx::AContext, ::typeof(kron), a::AbstractMatrix, b::AbstractMatrix)
res, back = _pullback(cx, _kron, a, b)
return res, back ∘ unthunk_tangent
end

@adjoint logabsdet(xs::AbstractMatrix) = logabsdet(xs), Δ -> (Δ[1] * inv(xs)',)

Expand Down Expand Up @@ -432,15 +435,6 @@ end
@adjoint LinearAlgebra.UnitLowerTriangular(A) = UnitLowerTriangular(A), Δ->(UnitLowerTriangular(Δ)-I,)
@adjoint LinearAlgebra.UnitUpperTriangular(A) = UnitUpperTriangular(A), Δ->(UnitUpperTriangular(Δ)-I,)

# This is basically a hack while we don't have a working `ldiv!`.
@adjoint function \(A::Cholesky, B::AbstractVecOrMat)
Y, back = Zygote.pullback((U, B)->U \ (U' \ B), A.U, B)
return Y, function(Ȳ)
Ā_factors, B̄ = back(Ȳ)
return ((uplo=nothing, info=nothing, factors=Ā_factors), B̄)
end
end

function _symmetric_back(Δ, uplo)
L, U, D = LowerTriangular(Δ), UpperTriangular(Δ), Diagonal(Δ)
return uplo == 'U' ? U .+ transpose(L) - D : L .+ transpose(U) - D
Expand Down Expand Up @@ -572,14 +566,14 @@ _hermsympow(A::Hermitian, p::Integer) = A^p

@adjoint function _hermsympow(A::Hermitian, p::Integer)
if p < 0
B, back = Zygote.pullback(A->Base.power_by_squaring(inv(A), -p), A)
B, back = _pullback(__context__, A -> Base.power_by_squaring(inv(A), -p), A)
else
B, back = Zygote.pullback(A->Base.power_by_squaring(A, p), A)
B, back = _pullback(__context__, A -> Base.power_by_squaring(A, p), A)
end
Ω = Hermitian(_realifydiag!(B))
return Ω, function (Ω̄)
B̄ = _hermitian_back(Ω̄, 'U')
Ā = back(B̄)[1]
Ā = last(back(B̄))
return (Ā, nothing)
end
end
Expand Down Expand Up @@ -812,8 +806,8 @@ end
# =======================

@adjoint function broadcasted(op, r::AbstractFill{<:Real})
y, _back = Zygote.pullback(op, getindex_value(r))
back(Δ::AbstractFill) = (nothing, Fill(_back(getindex_value(Δ))[1], size(r)))
back(Δ::AbstractArray) = (nothing, getindex.(_back.(Δ), 1))
y, _back = _pullback(__context__, op, getindex_value(r))
back(Δ::AbstractFill) = (nothing, Fill(last(_back(getindex_value(Δ))), size(r)))
back(Δ::AbstractArray) = (nothing, last.(_back.(Δ)))
return Fill(y, size(r)), back
end
7 changes: 3 additions & 4 deletions src/lib/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,9 @@ end
# For merge between NamedTuple and Dict, we will just convert the Dict to a NamedTuple.
# and then call `pullback`, which should overall be pretty efficient code generated,
# and it avoids trying to AD the problematic generic `merge(::NamedTuple, ::iter)` method which uses `push!`.
if VERSION >= v"1.6"
@adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, NamedTuple(dict))
else
@adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, (;dict...))
function _pullback(cx::AContext, ::typeof(merge), a::NamedTuple, b::Dict{Symbol})
res, back = _pullback(cx, merge, a, NamedTuple(b))
return res, back ∘ unthunk_tangent
end

# Keyword arguments pretend to be a Dict, but are secretly wrapping a NamedTuple.
Expand Down
21 changes: 16 additions & 5 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,13 @@ using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize
# Utilities
# =========

# ChainRules already marks this non-differentiable,
# But inference can still give up because of the Zygote -> CR wrapper layer
@nograd Broadcast.combine_styles
# ChainRules already marks this non-differentiable,# But inference can still give up because of the Zygote -> CR wrapper layer.
# This has been desugared from the (deprecated) @nograd macro.
@inline function Zygote._pullback(::AContext, ::typeof(Broadcast.combine_styles), args...)
dargs = ntuple(_ -> nothing, length(args) + 1)
combine_styles_pullback(_) = dargs
return Broadcast.combine_styles(args...), combine_styles_pullback
end

accum_sum(xs; dims = :) = reduce(accum, xs, dims = dims)

Expand Down Expand Up @@ -358,9 +362,16 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve

# Make sure sum(f, ::CuArray) uses broadcase through forward-mode defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
@adjoint function sum(f, xs::AbstractGPUArray; kws...)
function _pullback(cx::AContext, ::typeof(sum), f, xs::AbstractGPUArray)
res, back = _pullback(cx, (f, xs) -> sum(f.(xs)), f, xs)
return res, back ∘ unthunk_tangent
end
function _pullback(cx::AContext, ::Core.kwftype(typeof(sum)), kws, ::typeof(sum), f,
xs::AbstractGPUArray)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback((f, xs) -> sum(f.(xs); kws...), __context__, f, xs)
res, back = _pullback(cx, (f, xs) -> sum(f.(xs); kws...), f, xs)
sum_gpuarray_kw_pullback(Δ) = (nothing, nothing, back(unthunk_tangent(Δ))...)
return res, sum_gpuarray_kw_pullback
end

@adjoint function Base.convert(::Type{T}, xs::Array) where {T<:AbstractGPUArray}
Expand Down
18 changes: 14 additions & 4 deletions src/lib/distances.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,32 @@ end

_sqrt_if_positive(d, δ) = d > δ ? sqrt(d) : zero(d)

@adjoint function pairwise(dist::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
function _pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
kws::@NamedTuple{dims::Int}, ::typeof(pairwise), dist::Euclidean,
X::AbstractMatrix, Y::AbstractMatrix)
# Modify the forwards-pass slightly to ensure stability on the reverse.
dims = kws.dims
function _pairwise_euclidean(sqdist::SqEuclidean, X, Y)
D2 = pairwise(sqdist, X, Y; dims=dims)
δ = eps(eltype(D2))
return _sqrt_if_positive.(D2, δ)
end
return pullback(_pairwise_euclidean, SqEuclidean(dist.thresh), X, Y)
res, back = _pullback(cx, _pairwise_euclidean, SqEuclidean(dist.thresh), X, Y)
pairwise_Euclidean_pullback(Δ) = (nothing, nothing, back(unthunk_tangent(Δ))...)
return res, pairwise_Euclidean_pullback
end

@adjoint function pairwise(dist::Euclidean, X::AbstractMatrix; dims=2)
function _pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
kws::@NamedTuple{dims::Int}, ::typeof(pairwise), dist::Euclidean,
X::AbstractMatrix)
# Modify the forwards-pass slightly to ensure stability on the reverse.
dims = kws.dims
function _pairwise_euclidean(sqdist::SqEuclidean, X)
D2 = pairwise(sqdist, X; dims=dims)
δ = eps(eltype(D2))
return _sqrt_if_positive.(D2, δ)
end
return pullback(_pairwise_euclidean, SqEuclidean(dist.thresh), X)
res, back = _pullback(cx, _pairwise_euclidean, SqEuclidean(dist.thresh), X)
pairwise_Euclidean_pullback(Δ) = (nothing, nothing, back(unthunk_tangent(Δ))...)
return res, pairwise_Euclidean_pullback
end
20 changes: 12 additions & 8 deletions src/lib/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,32 +137,36 @@ end

# Use this to allow second derivatives -- this is forward-over-forward,
# see https://github.com/FluxML/Zygote.jl/issues/769 for a forward-over-reverse proposal
@adjoint function ForwardDiff.gradient(f, x)
function _pullback(cx::AContext, ::typeof(ForwardDiff.gradient), f, x)
F = typeof(f)
Base.issingletontype(F) || @warn """`ForwardDiff.gradient(f, x)` within Zygote cannot track gradients with respect to `f`,
and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
typeof(f) = $F""" maxlog=1 _id=hash(F)
pullback(forwarddiff, x -> ForwardDiff.gradient(f, x), x)
res, back = _pullback(cx, forwarddiff, x -> ForwardDiff.gradient(f, x), x)
return res, back ∘ unthunk_tangent
end

@adjoint function ForwardDiff.jacobian(f::F, x) where F
function _pullback(cx::AContext, ::typeof(ForwardDiff.jacobian), f::F, x) where F
Base.issingletontype(F) || @warn """`ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`,
and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
typeof(f) = $F""" maxlog=1 _id=hash(F)
pullback(forwarddiff, x -> ForwardDiff.jacobian(f, x), x)
res, back = _pullback(cx, forwarddiff, x -> ForwardDiff.jacobian(f, x), x)
return res, back ∘ unthunk_tangent
end

@adjoint function ForwardDiff.derivative(f::F, x) where F
function _pullback(cx::AContext, ::typeof(ForwardDiff.derivative), f::F, x) where F
Base.issingletontype(F) || @warn """`ForwardDiff.derivative(f, x)` within Zygote cannot track gradients with respect to `f`,
and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
typeof(f) = $F""" maxlog=1 _id=hash(F)
pullback(forwarddiff, x -> ForwardDiff.derivative(f, x), x)
res, back = _pullback(cx, forwarddiff, x -> ForwardDiff.derivative(f, x), x)
return res, back ∘ unthunk_tangent
end

@adjoint function ForwardDiff.hessian(f::F, x) where F
function _pullback(cx::AContext, ::typeof(ForwardDiff.hessian), f::F, x) where F
Base.issingletontype(F) || @warn """`ForwardDiff.hessian(f, x)` within Zygote cannot track gradients with respect to `f`,
and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
typeof(f) = $F""" maxlog=1 _id=hash(F)
pullback(forwarddiff, x -> ForwardDiff.hessian(f, x), x)
res, back = _pullback(cx, forwarddiff, x -> ForwardDiff.hessian(f, x), x)
return res, back ∘ unthunk_tangent
end

1 change: 0 additions & 1 deletion test/cuda.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using CUDA
using Zygote: Grads
using LinearAlgebra
using Random: randn!
import FiniteDifferences
CUDA.allowscalar(false)
Expand Down
2 changes: 0 additions & 2 deletions test/forward/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ end == 0
@test D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1+0im)[1]
@test real(D(x -> abs(x+2im), 1)) == gradient(x -> abs(x+2im), 1)[1] # ProjectTo means gradient here is real

using LinearAlgebra

@test D(3) do x
A = zeros(5, 5)
B = zeros(5, 5)
Expand Down
1 change: 0 additions & 1 deletion test/lib/array.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using ChainRulesTestUtils
using LinearAlgebra
using Zygote: ZygoteRuleConfig, _pullback

# issue 897
Expand Down
2 changes: 0 additions & 2 deletions test/lib/base.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using LinearAlgebra;

@testset "base.jl" begin
@testset "Dict getindex with implicit params" begin
d = Dict{String, Vector{Float64}}("key"=>ones(4))
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Zygote, Test
using Zygote, Test, LinearAlgebra
using Zygote: gradient, ZygoteRuleConfig
using CUDA
using CUDA: has_cuda
Expand Down
1 change: 0 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using LinearAlgebra
using ForwardDiff
using Zygote: hessian_dual, hessian_reverse

Expand Down