Skip to content

Commit

Permalink
Merge pull request #427 from mcabbott/projecttangents
Browse files Browse the repository at this point in the history
Fix #426 -- gradient of Ref is a Tangent
  • Loading branch information
mcabbott authored Aug 19, 2021
2 parents 2208660 + eb872cd commit 8218c2c
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 56 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.3.0"
version = "1.3.1"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
4 changes: 4 additions & 0 deletions src/differentials/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ Base.transpose(z::AbstractZero) = z
Base.:/(z::AbstractZero, ::Any) = z

Base.convert(::Type{T}, x::AbstractZero) where T <: Number = zero(T)
# Calling a Number type on one or more zero tangents, e.g. `Float64(ZeroTangent())`, gives
# `zero(T)`. At least one argument is required: a bare `xs::AbstractZero...` also matches the
# zero-argument call, which would define `T()` (e.g. `Float64()`) for Number types.
(::Type{T})(::AbstractZero, ::AbstractZero...) where {T<:Number} = zero(T)

# Mixed zero-tangent / real arguments: `false` stands in for the zero component, and the
# ordinary two-argument `Complex` constructor is then called with the remaining real value.
(::Type{Complex})(x::AbstractZero, y::Real) = Complex(false, y)
(::Type{Complex})(x::Real, y::AbstractZero) = Complex(x, false)

Base.getindex(z::AbstractZero, k) = z

Expand Down
94 changes: 49 additions & 45 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,8 @@ function generic_projector(x::T; kw...) where {T}
fields_nt::NamedTuple = backing(x)
fields_proj = map(_maybe_projector, fields_nt)
# We can't use `T` because if we have `Foo{Matrix{E}}` it should be allowed to make a
# `Foo{Diagonal{E}}` etc. We assume it has a default constructor that has all fields
# but if it doesn't `construct` will give a good error message.
# `Foo{Diagonal{E}}` etc. Official API for this? https://github.com/JuliaLang/julia/issues/35543
wrapT = T.name.wrapper
# Official API for this? https://github.com/JuliaLang/julia/issues/35543
return ProjectTo{wrapT}(; fields_proj..., kw...)
end

Expand All @@ -72,12 +70,6 @@ function generic_projection(project::ProjectTo{T}, dx::T) where {T}
return construct(T, map(_maybe_call, sub_projects, sub_dxs))
end

function (project::ProjectTo{T})(dx::Tangent) where {T}
sub_projects = backing(project)
sub_dxs = backing(canonicalize(dx))
return construct(T, map(_maybe_call, sub_projects, sub_dxs))
end

# Used for encoding fields, leaves alone non-diff types:
_maybe_projector(x::Union{AbstractArray,Number,Ref}) = ProjectTo(x)
_maybe_projector(x) = x
Expand Down Expand Up @@ -123,7 +115,6 @@ ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(2),
ProjectTo(::Any) # just to attach docstring

# Generic
(::ProjectTo{T})(dx::T) where {T} = dx # not always correct but we have special cases for when it isn't
(::ProjectTo{T})(dx::AbstractZero) where {T} = dx
(::ProjectTo{T})(dx::NotImplemented) where {T} = dx

Expand All @@ -133,7 +124,17 @@ ProjectTo(::Any) # just to attach docstring
# Zero
ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pass makes this one projector,
(::ProjectTo{NoTangent})(dx) = NoTangent() # but this is the projection only for nonzero gradients,
(::ProjectTo{NoTangent})(::NoTangent) = NoTangent() # and this one solves an ambiguity.
(::ProjectTo{NoTangent})(dx::AbstractZero) = dx # and this one solves an ambiguity.

# Also, any explicit construction with fields, where all fields project to zero, itself
# projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]).
const _PZ = ProjectTo{<:AbstractZero}
# Matches a NamedTuple of one or more zero projectors. Tuple types are covariant, so a plain
# `Vararg{_PZ}` already accepts any mix of `ProjectTo{<:AbstractZero}` types; the form
# `Vararg{<:_PZ}` wraps `Vararg` directly in a `UnionAll`, which is deprecated since Julia 1.7.
ProjectTo{P}(::NamedTuple{T,<:Tuple{_PZ,Vararg{_PZ}}}) where {P,T} = ProjectTo{NoTangent}()

# Tangent
# We haven't entirely figured out when to convert Tangents to "natural" representations such as
# dx::AbstractArray (when both are possible), or the reverse. So for now we just pass them through:
(::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx

#####
##### `Base`
Expand Down Expand Up @@ -165,27 +166,29 @@ end
(::ProjectTo{T})(dx::Integer) where {T<:Complex{<:AbstractFloat}} = convert(T, dx)

# Other numbers, including e.g. ForwardDiff.Dual and Symbolics.Sym, should pass through.
# We assume (lacking evidence to the contrary) that it is the right subspace of numbers
# The (::ProjectTo{T})(::T) method doesn't work because we are allowing a different
# Number type that might not be a subtype of the `project_type`.
# We assume (lacking evidence to the contrary) that it is the right subspace of numbers.
(::ProjectTo{<:Number})(dx::Number) = dx

(project::ProjectTo{<:Real})(dx::Complex) = project(real(dx))
(project::ProjectTo{<:Complex})(dx::Real) = project(complex(dx))

# Tangents: we prefer to reconstruct numbers, but this is only safe to try when their constructor
# understands the fields, including a mix of Zeros & reals. In other cases, we just let them through:
(project::ProjectTo{<:Complex})(dx::Tangent{<:Complex}) = project(Complex(dx.re, dx.im))
(::ProjectTo{<:Number})(dx::Tangent{<:Number}) = dx

# Arrays
# If we don't have a more specialized `ProjectTo` rule, we just assume that there is
# no structure worth re-imposing. Then any array is acceptable as a gradient.

# For arrays of numbers, just store one projector:
function ProjectTo(x::AbstractArray{T}) where {T<:Number}
element = T <: Irrational ? ProjectTo{Real}() : ProjectTo(zero(T))
if element isa ProjectTo{<:AbstractZero}
return ProjectTo{NoTangent}() # short-circuit if all elements project to zero
else
return ProjectTo{AbstractArray}(; element=element, axes=axes(x))
end
return ProjectTo{AbstractArray}(; element=_eltype_projectto(T), axes=axes(x))
end
ProjectTo(x::AbstractArray{Bool}) = ProjectTo{NoTangent}()

# Projector for the elements of an array of numbers, chosen from the element type alone.
# Generic number types reuse the scalar projector of `zero(T)`; `Irrational` element types
# get `ProjectTo{Real}()` instead.
_eltype_projectto(::Type{T}) where {T<:Number} = ProjectTo(zero(T))
_eltype_projectto(::Type{<:Irrational}) = ProjectTo{Real}()

# In other cases, store a projector per element:
function ProjectTo(xs::AbstractArray)
Expand Down Expand Up @@ -241,27 +244,39 @@ function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore fro
return fill(project.element(dx))
end

# Ref -- works like a zero-array, also allows restoration from a number:
ProjectTo(x::Ref) = ProjectTo{Ref}(; x=ProjectTo(x[]))
(project::ProjectTo{Ref})(dx::Ref) = Ref(project.x(dx[]))
(project::ProjectTo{Ref})(dx::Number) = Ref(project.x(dx))

# Build (but do not throw) the `DimensionMismatch` describing a gradient whose size
# `size_dx` does not fit a variable whose axes are `axes_x`.
function _projection_mismatch(axes_x::Tuple, size_dx::Tuple)
    # Axes are reported as a size, one length per dimension.
    size_x = length.(axes_x)
    msg = "variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx"
    return DimensionMismatch(msg)
end

#####
##### `Base`, part II: return of the Tangent
#####

# Ref
# Build the projector for a `Ref` from the projector of its contents.
function ProjectTo(x::Ref)
    # should we worry about isdefined(Ref{Vector{Int}}(), :x)?
    inner = ProjectTo(x[])
    # Contents that project to zero make the whole Ref project to zero.
    inner isa ProjectTo{<:AbstractZero} && return ProjectTo{NoTangent}()
    # Otherwise remember the concrete Ref type alongside the inner projector.
    return ProjectTo{Ref}(; type=typeof(x), x=inner)
end
# Whichever form the gradient arrives in, the result is a `Tangent` for the Ref type stored
# in the projector, with its single field `x` passed through the inner projector `project.x`.
(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x=project.x(dx.x))
(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x=project.x(dx[]))
# Since this works like a zero-array in broadcasting, it should also accept a number:
(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x=project.x(dx))

#####
##### `LinearAlgebra`
#####

using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec

# Row vectors
function ProjectTo(x::LinearAlgebra.AdjointAbsVec)
sub = ProjectTo(parent(x))
return ProjectTo{Adjoint}(; parent=sub)
end
ProjectTo(x::AdjointAbsVec) = ProjectTo{Adjoint}(; parent=ProjectTo(parent(x)))
# Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec.
# Transposed matrices are, like PermutedDimsArray, just a storage detail,
# but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number
Expand All @@ -276,10 +291,7 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray)
return adjoint(project.parent(dy))
end

function ProjectTo(x::LinearAlgebra.TransposeAbsVec)
sub = ProjectTo(parent(x))
return ProjectTo{Transpose}(; parent=sub)
end
ProjectTo(x::LinearAlgebra.TransposeAbsVec) = ProjectTo{Transpose}(; parent=ProjectTo(parent(x)))
function (project::ProjectTo{Transpose})(dx::LinearAlgebra.AdjOrTransAbsVec)
return transpose(project.parent(transpose(dx)))
end
Expand All @@ -292,11 +304,7 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray)
end

# Diagonal
function ProjectTo(x::Diagonal)
sub = ProjectTo(x.diag)
sub isa ProjectTo{<:AbstractZero} && return sub # TODO not necc if Diagonal(NoTangent()) worked
return ProjectTo{Diagonal}(; diag=sub)
end
ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag))
(project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx)))
(project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag))

Expand All @@ -308,7 +316,8 @@ for (SymHerm, chk, fun) in (
@eval begin
function ProjectTo(x::$SymHerm)
sub = ProjectTo(parent(x))
sub isa ProjectTo{<:AbstractZero} && return sub # TODO not necc if Hermitian(NoTangent()) etc. worked
# Because the projector stores uplo, ProjectTo(Symmetric(rand(3,3) .> 0)) isn't automatically trivial:
sub isa ProjectTo{<:AbstractZero} && return sub
return ProjectTo{$SymHerm}(; uplo=LinearAlgebra.sym_uplo(x.uplo), parent=sub)
end
function (project::ProjectTo{$SymHerm})(dx::AbstractArray)
Expand All @@ -333,12 +342,7 @@ end
# Triangular
for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular) # UpperHessenberg
@eval begin
function ProjectTo(x::$UL)
sub = ProjectTo(parent(x))
# TODO not nesc if UnitUpperTriangular(NoTangent()) etc. worked
sub isa ProjectTo{<:AbstractZero} && return sub
return ProjectTo{$UL}(; parent=sub)
end
ProjectTo(x::$UL) = ProjectTo{$UL}(; parent=ProjectTo(parent(x)))
(project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.parent(dx))
function (project::ProjectTo{$UL})(dx::Diagonal)
sub = project.parent
Expand Down
10 changes: 8 additions & 2 deletions test/differentials/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,15 @@
@test complex(z, z) === z
@test complex(z, 2.0) === Complex{Float64}(0.0, 2.0)
@test complex(1.5, z) === Complex{Float64}(1.5, 0.0)
@test Complex(z, 2.0) === Complex{Float64}(0.0, 2.0)
@test Complex(1.5, z) === Complex{Float64}(1.5, 0.0)
@test ComplexF64(z, 2.0) === Complex{Float64}(0.0, 2.0)
@test ComplexF64(1.5, z) === Complex{Float64}(1.5, 0.0)

@test convert(Int64, ZeroTangent()) == 0
@test convert(Float64, ZeroTangent()) == 0.0
@test convert(Bool, ZeroTangent()) === false
@test convert(Int64, ZeroTangent()) === Int64(0)
@test convert(Float32, ZeroTangent()) === 0.0f0
@test convert(ComplexF64, ZeroTangent()) === 0.0 + 0.0im
end

@testset "NoTangent" begin
Expand Down
50 changes: 42 additions & 8 deletions test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,24 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
@test ProjectTo(1.0f0 + 2im)(3) === 3.0f0 + 0im
@test ProjectTo(big(1.0))(2) === 2
@test ProjectTo(1.0)(2) === 2.0

# Tangents
# A Tangent of a Complex is rebuilt into a number and then projected to the primal's type.
@test ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(re=1, im=NoTangent())) === 1.0f0 + 0.0f0im
end

@testset "Dual" begin # some weird Real subtype that we should basically leave alone
@test ProjectTo(1.0)(Dual(1.0, 2.0)) isa Dual
@test ProjectTo(1.0)(Dual(1, 2)) isa Dual

# real & complex
@test ProjectTo(1.0 + 1im)(Dual(1.0, 2.0)) isa Complex{<:Dual}
@test ProjectTo(1.0 + 1im)(
Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))
) isa Complex{<:Dual}
@test ProjectTo(1.0)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa Dual

# Tangent
@test ProjectTo(Dual(1.0, 2.0))(Tangent{Dual}(; value=1.0)) isa Tangent
end

@testset "Base: arrays of numbers" begin
Expand Down Expand Up @@ -100,7 +108,7 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
@test ProjectTo(Bool[]) isa ProjectTo{NoTangent}
end

@testset "Base: zero-arrays & Ref" begin
@testset "Base: zero-arrays" begin
pzed = ProjectTo(fill(1.0))
@test pzed(fill(3.14)) == fill(3.14) # easy
@test pzed(fill(3)) == fill(3.0) # broadcast type change must not produce number
Expand All @@ -110,17 +118,26 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
@test_throws DimensionMismatch ProjectTo([1])(3.14 + im) # other array projectors don't accept numbers
@test_throws DimensionMismatch ProjectTo(hcat([1, 2]))(3.14)
@test pzed isa ProjectTo{AbstractArray}
end

@testset "Base: Ref" begin
pref = ProjectTo(Ref(2.0))
@test pref(Ref(3 + im))[] === 3.0
@test pref(4)[] === 4.0 # also re-wraps scalars
@test pref(Ref{Any}(5.0)) isa Base.RefValue{Float64}
@test pref(Ref(3 + im)).x === 3.0
@test pref(Tangent{Base.RefValue}(x = 3 + im)).x === 3.0
@test pref(4).x === 4.0 # also re-wraps scalars
@test pref(Ref{Any}(5.0)) isa Tangent{<:Base.RefValue}

pref2 = ProjectTo(Ref{Any}(6 + 7im))
@test pref2(Ref(8))[] === 8.0 + 0.0im
@test pref2(Ref(8)).x === 8.0 + 0.0im
@test pref2(Tangent{Base.RefValue}(x = 8)).x === 8.0 + 0.0im

prefvec = ProjectTo(Ref([1, 2, 3 + 4im])) # recurses into contents
@test prefvec(Ref(1:3)) isa Base.RefValue{Vector{ComplexF64}}
@test_throws DimensionMismatch prefvec(Ref{Any}(1:5))
@test prefvec(Ref(1:3)).x isa Vector{ComplexF64}
@test prefvec(Tangent{Base.RefValue}(x = 1:3)).x isa Vector{ComplexF64}
@test_skip @test_throws DimensionMismatch prefvec(Tangent{Base.RefValue}(x = 1:5))

@test ProjectTo(Ref(true)) isa ProjectTo{NoTangent}
@test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent}
end

#####
Expand Down Expand Up @@ -167,6 +184,9 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))

# issue #410
@test padj([NoTangent() NoTangent() NoTangent()]) === NoTangent()

@test ProjectTo(adj([true, false]))([1 2]) isa AbstractZero
@test ProjectTo(adj([[true], [false]])) isa ProjectTo{<:AbstractZero}
end

@testset "LinearAlgebra: dense structured matrices" begin
Expand Down Expand Up @@ -284,11 +304,12 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
@testset "AbstractZero" begin
pz = ProjectTo(ZeroTangent())
# A non-zero gradient arriving at a zero projector is mapped to NoTangent.
@test pz(0) == NoTangent()
@test_broken pz(ZeroTangent()) === ZeroTangent() # not sure how NB this is to preserve
@test pz(ZeroTangent()) === ZeroTangent() # not sure how NB this is to preserve
@test pz(NoTangent()) === NoTangent()

pb = ProjectTo(true) # Bool is categorical
@test pb(2) === NoTangent()
@test pb(ZeroTangent()) isa AbstractZero # was a method ambiguity!

# all projectors preserve Zero, and specific type, via one fallback method:
@test ProjectTo(pi)(ZeroTangent()) === ZeroTangent()
Expand All @@ -305,6 +326,19 @@ Base.zero(x::Dual) = Dual(zero(x.value), zero(x.partial))
@test unthunk(pth) === 6.0 + 0.0im
end

@testset "Tangent" begin
x = 1:3.0
dx = Tangent{typeof(x)}(; step=0.1, ref=NoTangent());
@test ProjectTo(x)(dx) isa Tangent
@test ProjectTo(x)(dx).step === 0.1
@test ProjectTo(x)(dx).offset isa AbstractZero

pref = ProjectTo(Ref(2.0))
dy = Tangent{typeof(Ref(2.0))}(x = 3+4im)
@test pref(dy) isa Tangent{<:Base.RefValue}
@test pref(dy).x === 3.0
end

@testset "display" begin
@test repr(ProjectTo(1.1)) == "ProjectTo{Float64}()"
@test occursin("ProjectTo{AbstractArray}(element", repr(ProjectTo([1, 2, 3])))
Expand Down

1 comment on commit 8218c2c

@mcabbott
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.