Skip to content

Commit

Permalink
Add many frules (#565)
Browse files Browse the repository at this point in the history
* drop 1.0, now that LTS == 1.6

* revert to one Project

* rm Compat

* turns out this does still need Compat

* add many frules

* in-place frules

* reshape + dropdims too

* tests

* 5-arg mul

* notation changes

* rm 2nd order rules

* don't skip setindex

* AbstractArray constructors

* reshape tests

* Apply 4 suggestions

Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>

* fixup, bump

* several comments, and one rule for PermutedDimsArray

* in fact sortslices is fine with offsets

Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
  • Loading branch information
mcabbott and oxinabox authored Jan 25, 2022
1 parent 13a362c commit 8c34f19
Show file tree
Hide file tree
Showing 13 changed files with 418 additions and 37 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.23"
version = "1.24"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
169 changes: 140 additions & 29 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,37 @@

ChainRules.@non_differentiable (::Type{T} where {T<:Array})(::UndefInitializer, args...)

function frule((_, ẋ), ::Type{T}, x::AbstractArray) where {T<:Array}
return T(x), T(ẋ)
end

function frule((_, ẋ), ::Type{AbstractArray{T}}, x::AbstractArray) where {T}
return AbstractArray{T}(x), AbstractArray{T}(ẋ)
end

function rrule(::Type{T}, x::AbstractArray) where {T<:Array}
project_x = ProjectTo(x)
Array_pullback(ȳ) = (NoTangent(), project_x(ȳ))
return T(x), Array_pullback
end

# This abstract one is used for `float(x)` and other float conversion purposes:
function rrule(::Type{AbstractArray{T}}, x::AbstractArray) where {T}
project_x = ProjectTo(x)
AbstractArray_pullback(ȳ) = (NoTangent(), project_x(ȳ))
return AbstractArray{T}(x), AbstractArray_pullback
end

#####
##### `vect`
#####

@non_differentiable Base.vect()

function frule((_, ẋs...), ::typeof(Base.vect), xs::Number...)
return Base.vect(xs...), Base.vect(_instantiate_zeros(ẋs, xs)...)
end

# Case of uniform type `T`: the data passes straight through,
# so no projection should be required.
function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N}
Expand Down Expand Up @@ -43,32 +62,84 @@ function rrule(::typeof(Base.vect), X::Vararg{Any,N}) where {N}
return Base.vect(X...), vect_pullback
end

"""
_instantiate_zeros(ẋs, xs)
Forward rules for `vect`, `cat` etc may receive a mixture of data and `ZeroTangent`s.
To avoid `vect(1, ZeroTangent(), 3)` or worse `vcat([1,2], ZeroTangent(), [6,7])`, this
materialises each zero `ẋ` to be `zero(x)`.
"""
_instantiate_zeros(ẋs, xs) = map(_i_zero, ẋs, xs)
_i_zero(ẋ, x) =
_i_zero(ẋ::AbstractZero, x) = zero(x)
# Possibly this won't work for partly non-diff arrays, sometihng like `gradient(x -> ["abc", x][end], 1)`
# may give a MethodError for `zero` but won't be wrong.

# Fast paths. Should it also collapse all-Zero cases?
_instantiate_zeros(ẋs::Tuple{Vararg{<:Number}}, xs) = ẋs
_instantiate_zeros(ẋs::Tuple{Vararg{<:AbstractArray}}, xs) = ẋs
_instantiate_zeros(ẋs::AbstractArray{<:Number}, xs) = ẋs
_instantiate_zeros(ẋs::AbstractArray{<:AbstractArray}, xs) = ẋs

#####
##### `copyto!`
#####

function frule((_, ẏ, ẋ), ::typeof(copyto!), y::AbstractArray, x)
return copyto!(y, x), copyto!(ẏ, ẋ)
end

function frule((_, ẏ, _, ẋ), ::typeof(copyto!), y::AbstractArray, i::Integer, x, js::Integer...)
return copyto!(y, i, x, js...), copyto!(ẏ, i, ẋ, js...)
end

#####
##### `reshape`
#####

function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Union{Colon,Int}}})
A_dims = size(A)
function reshape_pullback(Ȳ)
return (NoTangent(), reshape(Ȳ, A_dims), NoTangent())
end
return reshape(A, dims), reshape_pullback
function frule((_, ẋ), ::typeof(reshape), x::AbstractArray, dims...)
return reshape(x, dims...), reshape(ẋ, dims...)
end

function rrule(::typeof(reshape), A::AbstractArray, dims::Union{Colon,Int}...)
A_dims = size(A)
function reshape_pullback(Ȳ)
∂A = reshape(Ȳ, A_dims)
∂dims = broadcast(Returns(NoTangent()), dims)
return (NoTangent(), ∂A, ∂dims...)
end
function rrule(::typeof(reshape), A::AbstractArray, dims...)
ax = axes(A)
project = ProjectTo(A) # Projection is here for e.g. reshape(::Diagonal, :)
∂dims = broadcast(Returns(NoTangent()), dims)
reshape_pullback(Ȳ) = (NoTangent(), project(reshape(Ȳ, ax)), ∂dims...)
return reshape(A, dims...), reshape_pullback
end

#####
##### `dropdims`
#####

function frule((_, ẋ), ::typeof(dropdims), x::AbstractArray; dims)
return dropdims(x; dims), dropdims(ẋ; dims)
end

function rrule(::typeof(dropdims), A::AbstractArray; dims)
ax = axes(A)
project = ProjectTo(A)
dropdims_pullback(Ȳ) = (NoTangent(), project(reshape(Ȳ, ax)))
return dropdims(A; dims), dropdims_pullback
end

#####
##### `permutedims`
#####

function frule((_, ẋ), ::typeof(permutedims), x::AbstractArray, perm...)
return permutedims(x, perm...), permutedims(ẋ, perm...)
end

function frule((_, ẏ, ẋ), ::typeof(permutedims!), y::AbstractArray, x::AbstractArray, perm...)
return permutedims!(y, x, perm...), permutedims!(ẏ, ẋ, perm...)
end

function frule((_, ẋ), ::Type{<:PermutedDimsArray}, x::AbstractArray, perm)
return PermutedDimsArray(x, perm), PermutedDimsArray(ẋ, perm)
end

function rrule(::typeof(permutedims), x::AbstractVector)
project = ProjectTo(x)
permutedims_pullback_1(dy) = (NoTangent(), project(permutedims(unthunk(dy))))
Expand All @@ -91,6 +162,10 @@ end
##### `repeat`
#####

function frule((_, ẋs), ::typeof(repeat), xs::AbstractArray, cnt...; kw...)
return repeat(xs, cnt...; kw...), repeat(ẋs, cnt...; kw...)
end

function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(Returns(1), ndims(xs)), outer=ntuple(Returns(1), ndims(xs)))

project_Xs = ProjectTo(xs)
Expand Down Expand Up @@ -130,6 +205,10 @@ end
##### `hcat`
#####

function frule((_, ẋs...), ::typeof(hcat), xs...)
return hcat(xs...), hcat(_instantiate_zeros(ẋs, xs)...)
end

function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...)
Y = hcat(Xs...) # note that Y always has 1-based indexing, even if X isa OffsetArray
ndimsY = Val(ndims(Y)) # this avoids closing over Y, Val() is essential for type-stability
Expand Down Expand Up @@ -164,6 +243,10 @@ function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...)
return Y, hcat_pullback
end

function frule((_, _, Ȧs), ::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat})
return reduce(hcat, As), reduce(hcat, _instantiate_zeros(Ȧs, As))
end

function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat})
widths = map(A -> size(A,2), As)
function reduce_hcat_pullback_2(dY)
Expand Down Expand Up @@ -192,6 +275,10 @@ end
##### `vcat`
#####

function frule((_, ẋs...), ::typeof(vcat), xs...)
return vcat(xs...), vcat(_instantiate_zeros(ẋs, xs)...)
end

function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...)
Y = vcat(Xs...)
ndimsY = Val(ndims(Y))
Expand Down Expand Up @@ -224,6 +311,10 @@ function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...)
return Y, vcat_pullback
end

function frule((_, _, Ȧs), ::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat})
return reduce(vcat, As), reduce(vcat, _instantiate_zeros(Ȧs, As))
end

function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat})
Y = reduce(vcat, As)
ndimsY = Val(ndims(Y))
Expand All @@ -247,6 +338,10 @@ end

_val(::Val{x}) where {x} = x

function frule((_, ẋs...), ::typeof(cat), xs...; dims)
return cat(xs...; dims), cat(_instantiate_zeros(ẋs, xs)...; dims)
end

function rrule(::typeof(cat), Xs::Union{AbstractArray, Number}...; dims)
Y = cat(Xs...; dims=dims)
cdims = dims isa Val ? Int(_val(dims)) : dims isa Integer ? Int(dims) : Tuple(dims)
Expand Down Expand Up @@ -285,6 +380,10 @@ end
##### `hvcat`
#####

function frule((_, _, ẋs...), ::typeof(hvcat), rows, xs...)
return hvcat(rows, xs...), hvcat(rows, _instantiate_zeros(ẋs, xs)...)
end

function rrule(::typeof(hvcat), rows, values::Union{AbstractArray, Number}...)
Y = hvcat(rows, values...)
cols = size(Y,2)
Expand Down Expand Up @@ -321,8 +420,12 @@ end
# 1-dim case allows start/stop, N-dim case takes dims keyword
# whose defaults changed in Julia 1.6... just pass them all through:

function frule((_, xdot), ::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...)
return reverse(x, args...; kw...), reverse(xdot, args...; kw...)
function frule((_, ẋ), ::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...)
return reverse(x, args...; kw...), reverse(ẋ, args...; kw...)
end

function frule((_, ẋ), ::typeof(reverse!), x::Union{AbstractArray, Tuple}, args...; kw...)
return reverse!(x, args...; kw...), reverse!(ẋ, args...; kw...)
end

function rrule(::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...)
Expand All @@ -338,8 +441,12 @@ end
##### `circshift`
#####

function frule((_, xdot), ::typeof(circshift), x::AbstractArray, shifts)
return circshift(x, shifts), circshift(xdot, shifts)
function frule((_, ẋ), ::typeof(circshift), x::AbstractArray, shifts)
return circshift(x, shifts), circshift(ẋ, shifts)
end

function frule((_, ẏ, ẋ), ::typeof(circshift!), y::AbstractArray, x::AbstractArray, shifts)
return circshift!(y, x, shifts), circshift!(ẏ, ẋ, shifts)
end

function rrule(::typeof(circshift), x::AbstractArray, shifts)
Expand All @@ -355,8 +462,12 @@ end
##### `fill`
#####

function frule((_, xdot), ::typeof(fill), x::Any, dims...)
return fill(x, dims...), fill(xdot, dims...)
function frule((_, ẋ), ::typeof(fill), x::Any, dims...)
return fill(x, dims...), fill(ẋ, dims...)
end

function frule((_, ẏ, ẋ), ::typeof(fill!), y::AbstractArray, x::Any)
return fill!(y, x), fill!(ẏ, ẋ)
end

function rrule(::typeof(fill), x::Any, dims...)
Expand All @@ -370,9 +481,9 @@ end
##### `filter`
#####

function frule((_, _, xdot), ::typeof(filter), f, x::AbstractArray)
function frule((_, _, ), ::typeof(filter), f, x::AbstractArray)
inds = findall(f, x)
return x[inds], xdot[inds]
return x[inds], [inds]
end

function rrule(::typeof(filter), f, x::AbstractArray)
Expand All @@ -392,9 +503,9 @@ end
for findm in (:findmin, :findmax)
findm_pullback = Symbol(findm, :_pullback)

@eval function frule((_, xdot), ::typeof($findm), x; dims=:)
@eval function frule((_, ), ::typeof($findm), x; dims=:)
y, ind = $findm(x; dims=dims)
return (y, ind), Tangent{typeof((y, ind))}(xdot[ind], NoTangent())
return (y, ind), Tangent{typeof((y, ind))}([ind], NoTangent())
end

@eval function rrule(::typeof($findm), x::AbstractArray; dims=:)
Expand Down Expand Up @@ -441,8 +552,8 @@ end
# Allow for second derivatives, by writing rules for `_zerolike_writeat`;
# these rules are the reason it takes a `dims` argument.

function frule((_, _, dydot), ::typeof(_zerolike_writeat), x, dy, dims, inds...)
return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat(x, dydot, dims, inds...)
function frule((_, _, dẏ), ::typeof(_zerolike_writeat), x, dy, dims, inds...)
return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat(x, dẏ, dims, inds...)
end

function rrule(::typeof(_zerolike_writeat), x, dy, dims, inds...)
Expand All @@ -457,9 +568,9 @@ end

# These rules for `maximum` pick the same subgradient as `findmax`:

function frule((_, xdot), ::typeof(maximum), x; dims=:)
function frule((_, ), ::typeof(maximum), x; dims=:)
y, ind = findmax(x; dims=dims)
return y, xdot[ind]
return y, [ind]
end

function rrule(::typeof(maximum), x::AbstractArray; dims=:)
Expand All @@ -468,9 +579,9 @@ function rrule(::typeof(maximum), x::AbstractArray; dims=:)
return y, maximum_pullback
end

function frule((_, xdot), ::typeof(minimum), x; dims=:)
function frule((_, ), ::typeof(minimum), x; dims=:)
y, ind = findmin(x; dims=dims)
return y, xdot[ind]
return y, [ind]
end

function rrule(::typeof(minimum), x::AbstractArray; dims=:)
Expand Down
24 changes: 23 additions & 1 deletion src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ end
##### `*`
#####

frule((_, ΔA, ΔB), ::typeof(*), A, B) = A * B, muladd(ΔA, B, A * ΔB)

frule((_, ΔA, ΔB, ΔC), ::typeof(*), A, B, C) = A*B*C, ΔA*B*C + A*ΔB*C + A*B*ΔC


function rrule(
::typeof(*),
Expand Down Expand Up @@ -88,7 +92,9 @@ function rrule(
end



#####
##### `*` matrix-scalar_rule
#####

function rrule(
::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber}
Expand Down Expand Up @@ -204,6 +210,11 @@ end # VERSION
##### `muladd`
#####

function frule((_, ΔA, ΔB, Δz), ::typeof(muladd), A, B, z)
Ω = muladd(A, B, z)
return Ω, ΔA * B .+ A * ΔB .+ Δz
end

function rrule(
::typeof(muladd),
A::AbstractMatrix{<:CommutativeMulNumber},
Expand Down Expand Up @@ -351,6 +362,13 @@ end
##### `\`, `/` matrix-scalar_rule
#####

function frule((_, ΔA, Δb), ::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber)
return A/b, ΔA/b - A*(Δb/b^2)
end
function frule((_, Δa, ΔB), ::typeof(\), a::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber})
return B/a, ΔB/a - B*(Δa/a^2)
end

function rrule(::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber)
Y = A/b
function slash_pullback_scalar(ȳ)
Expand Down Expand Up @@ -378,6 +396,8 @@ end
##### Negation (Unary -)
#####

frule((_, ΔA), ::typeof(-), A::AbstractArray) = -A, -ΔA

function rrule(::typeof(-), x::AbstractArray)
function negation_pullback(ȳ)
return NoTangent(), InplaceableThunk(ā ->.-= ȳ, @thunk(-ȳ))
Expand All @@ -390,6 +410,8 @@ end
##### Addition (Multiarg `+`)
#####

frule((_, ΔAs...), ::typeof(+), As::AbstractArray...) = +(As...), +(ΔAs...)

function rrule(::typeof(+), arrs::AbstractArray...)
y = +(arrs...)
arr_axs = map(axes, arrs)
Expand Down
Loading

2 comments on commit 8c34f19

@mcabbott
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/53185

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.24.0 -m "<description of version>" 8c34f19d3a8a8a224c9fbe20524d2a08c8b9bf81
git push origin v1.24.0

Please sign in to comment.