Skip to content

Commit

Permalink
Replace the uses of "differential" with "tangent" where appropriate (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mzgubic authored Nov 10, 2021
1 parent 4d27d7e commit f9304b3
Show file tree
Hide file tree
Showing 18 changed files with 115 additions and 117 deletions.
4 changes: 2 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ makedocs(;
"Introduction" => "index.md",
"How to use ChainRules as a rule author" => [
"Introduction" => "rule_author/intro.md",
"Tangent types" => "rule_author/differentials.md",
"Tangent types" => "rule_author/tangents.md",
#"`frule` and `rrule`" => "rule_author/rules.md", # TODO: a complete example
"Writing good rules" => "rule_author/writing_good_rules.md",
"Testing your rules" => "rule_author/testing.md",
Expand All @@ -77,7 +77,7 @@ makedocs(;
],
"Design" => [
"Changing the Primal" => "design/changing_the_primal.md",
"Many Tangent Types" => "design/many_differentials.md",
"Many Tangent Types" => "design/many_tangents.md",
],
"Videos" => "videos.md",
"FAQ" => "FAQ.md",
Expand Down
2 changes: 1 addition & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Pages = ["rule_definition_tools.jl"]
Private = false
```

## Differentials
## Tangent Types
```@autodocs
Modules = [ChainRulesCore]
Pages = [
Expand Down

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/src/rule_author/converting_zygoterules.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ See docs on [Constructors](@ref).
## Include the derivative with respect to the function object itself
The `ZygoteRules.@adjoint` macro automagically[^1] inserts an extra `nothing` in the return for the function it generates to represent the derivative of output with respect to the function object.
ChainRules as a philosophy avoids magic as much as possible, and thus require you to return it explicitly.
If it is a plain function (like `typeof(sin)`), then the differential will be [`NoTangent`](@ref).
If it is a plain function (like `typeof(sin)`), then the tangent will be [`NoTangent`](@ref).


[^1]: unless you write it in functor form (i.e. `@adjoint (f::MyType)(args...)=...`), in that case like for `rrule` you need to include it explictly.
Expand Down
2 changes: 1 addition & 1 deletion docs/src/rule_author/intro.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

This section of the docs will tell you everything you need to know about writing rules for your package.

It will help with understanding differential types, the anatomy of the `frule` and
It will help with understanding tangent types, the anatomy of the `frule` and
the `rrule`, and provide tips on writing good rules, as well as how to test them easily
using finite differences.

Expand Down
10 changes: 5 additions & 5 deletions docs/src/rule_author/superpowers/gradient_accumulation.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ end
The AD software must transform that into something which repeatedly sums up the gradient of each part:
`X̄ = ā + b̄`.

This requires that all differential types `D` must implement `+`: `+(::D, ::D)::D`.
This requires that all tangent types `D` must implement `+`: `+(::D, ::D)::D`.

We can note that in this particular case `` and `` will both be arrays.
This operation (`X̄ = ā + b̄`) will allocate one array to hold `ā`, another one to hold ``, and a third one to hold `ā + b̄`.
Expand Down Expand Up @@ -47,7 +47,7 @@ AD systems can generate `add!!` instead of `+` when accumulating gradient to tak

### Inplaceable Thunks (`InplaceableThunks`) avoid allocating values in the first place.
We got down to two allocations from using [`add!!`](@ref), but can we do better?
We can think of having a differential type which acts on a partially accumulated result, to mutate it to contain its current value plus the partial derivative being accumulated.
We can think of having a tangent type which acts on a partially accumulated result, to mutate it to contain its current value plus the partial derivative being accumulated.
Rather than having an actual computed value, we can just have a thing that will act on a value to perform the addition.
Let's illustrate it with our example.

Expand Down Expand Up @@ -79,9 +79,9 @@ The `val` field use a plain [`Thunk`](@ref) to avoid the computation (and thus a
!!! note "Do we need both representations?"
Right now every [`InplaceableThunk`](@ref) has two fields that need to be specified.
The value form (represented as a the [`Thunk`](@ref) typed field), and the action form (represented as the `add!` field).
It is possible in a future version of ChainRulesCore.jl we will work out a clever way to find the zero differential for arbitrary primal values.
Given that, we could always just determine the value form from `inplaceable.add!(zero_differential(primal))`.
There are some technical difficulties in finding the zero differentials, but this may be solved at some point.
It is possible in a future version of ChainRulesCore.jl we will work out a clever way to find the zero tangent for arbitrary primal values.
Given that, we could always just determine the value form from `inplaceable.add!(zero_tangent(primal))`.
There are some technical difficulties in finding the zero tangents, but this may be solved at some point.


The `+` operation on `InplaceableThunk`s is overloaded to [`unthunk`](@ref) that `val` field to get the value form.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# [Tangent types](@id tangents)

The values that come back from pullbacks or pushforwards are not always the same type as the input/outputs of the primal function.
They are differentials, which correspond roughly to something able to represent the difference between two values of the primal types.
A differential might be such a regular type, like a `Number`, or a `Matrix`, matching to the original type;
They are tangents, which correspond roughly to something able to represent the difference between two values of the primal types.
A tangent might be such a regular type, like a `Number`, or a `Matrix`, matching to the original type;
or it might be one of the [`AbstractTangent`](@ref ChainRulesCore.AbstractTangent) subtypes.

Differentials support a number of operations.
Tangents support a number of operations.
Most importantly: `+` and `*`, which let them act as mathematical objects.

The most important `AbstractTangent`s when getting started are the ones about avoiding work:
Expand All @@ -14,6 +14,6 @@ The most important `AbstractTangent`s when getting started are the ones about av
- [`ZeroTangent`](@ref): It is a special representation of `0`. It does great things around avoiding expanding `Thunks` in addition.

### Other `AbstractTangent`s:
- [`Tangent{P}`](@ref Tangent): this is the differential for tuples and structs. Use it like a `Tuple` or `NamedTuple`. The type parameter `P` is for the primal type.
- [`Tangent{P}`](@ref Tangent): this is the tangent for tuples and structs. Use it like a `Tuple` or `NamedTuple`. The type parameter `P` is for the primal type.
- [`NoTangent`](@ref): Zero-like, represents that the operation on this input is not differentiable. Its primal type is normally `Integer` or `Bool`.
- [`InplaceableThunk`](@ref): it is like a `Thunk` but it can do in-place `add!`.
18 changes: 9 additions & 9 deletions docs/src/rule_author/writing_good_rules.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ This woull be solved once [JuliaLang/julia#38241](https://github.com/JuliaLang/j

## Use `Thunk`s appropriately

If work is only required for one of the returned differentials, then it should be wrapped in a `@thunk` (potentially using a `begin`-`end` block).
If work is only required for one of the returned tangents, then it should be wrapped in a `@thunk` (potentially using a `begin`-`end` block).

If there are multiple return values, their computation should almost always be wrapped in a `@thunk`.

Expand Down Expand Up @@ -169,16 +169,16 @@ For example, if a primal type `P` overloads subtraction (`-(::P,::P)`) then that
Common cases for types that represent a [vector-space](https://en.wikipedia.org/wiki/Vector_space) (e.g. `Float64`, `Array{Float64}`) is that the natural tangent type is the same as the primal type.
However, this is not always the case.
For example for a [`PDiagMat`](https://github.com/JuliaStats/PDMats.jl) a natural tangent is `Diagonal` since there is no requirement that a positive definite diagonal matrix has a positive definite tangent.
Another example is for a `DateTime`, any `Period` subtype, such as `Millisecond` or `Nanosecond` is a natural differential.
Another example is for a `DateTime`, any `Period` subtype, such as `Millisecond` or `Nanosecond` is a natural tangent.
There are often many different natural tangent types for a given primal type.
However, they are generally closely related and duck-type the same.
For example, for most `AbstractArray` subtypes, most other `AbstractArray`s (of right size and element type) can be considered as natural tangent types.

Not all types have natural tangent types.
For example there is no natural differential for a `Tuple`.
For example there is no natural tangent for a `Tuple`.
It is not a `Tuple` since that doesn't have any method for `+`.
Similar is true for many `struct`s.
For those cases there is only a structural differential.
For those cases there is only a structural tangent.

### Structural tangents

Expand Down Expand Up @@ -216,10 +216,10 @@ In this sense they wrap either a natural or structural tangent.

## Use `@not_implemented` appropriately

You can use [`@not_implemented`](@ref) to mark missing differentials.
This is helpful if the function has multiple inputs or outputs, and you have worked out analytically and implemented some but not all differentials.
You can use [`@not_implemented`](@ref) to mark missing tangents.
This is helpful if the function has multiple inputs or outputs, and you have worked out analytically and implemented some but not all tangents.

It is recommended to include a link to a GitHub issue about the missing differential in the debugging information:
It is recommended to include a link to a GitHub issue about the missing tangent in the debugging information:
```julia
@not_implemented(
"""
Expand All @@ -229,9 +229,9 @@ It is recommended to include a link to a GitHub issue about the missing differen
)
```

Do not use `@not_implemented` if the differential does not exist mathematically (use `NoTangent()` instead).
Do not use `@not_implemented` if the tangent does not exist mathematically (use `NoTangent()` instead).

Note: [ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl) marks `@not_implemented` differentials as "test broken".
Note: [ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl) marks `@not_implemented` tangents as "test broken".

## Use rule definition tools

Expand Down
4 changes: 2 additions & 2 deletions src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod
export frule_via_ad, rrule_via_ad
# definition helper macros
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
export ProjectTo, canonicalize, unthunk # differential operations
export ProjectTo, canonicalize, unthunk # tangent operations
export add!! # gradient accumulation operations
export ignore_derivatives, @ignore_derivatives
# differentials
# tangents
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk

include("compat.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/accumulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ end
Returns true if `x` is suitable for for storing inplace accumulation of gradients.
For arrays this boils down `x .= y` if will work to mutate `x`, if `y` is an appropriate
differential.
tangent.
Wrapper array types do not need to overload this if they overload `Base.parent`, and are
`is_inplaceable_destination` if and only if their parent array is.
Other types should overload this, as it defaults to `false`.
Expand Down
4 changes: 2 additions & 2 deletions src/projection.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
(p::ProjectTo{T})(dx)
Projects the differential `dx` onto a specific tangent space.
Projects the tangent `dx` onto a specific tangent space.
The type `T` is meant to encode the largest acceptable space, so usually
this enforces `p(dx)::T`. But some subspaces which aren't subtypes of `T` may
Expand Down Expand Up @@ -80,7 +80,7 @@ _maybe_call(f, x) = f
"""
ProjectTo(x)
Returns a `ProjectTo{T}` functor which projects a differential `dx` onto the
Returns a `ProjectTo{T}` functor which projects a tangent `dx` onto the
relevant tangent space for `x`.
Custom `ProjectTo` methods are provided for many subtypes of `Number` (to e.g. ensure precision),
Expand Down
4 changes: 2 additions & 2 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Expressing the output of `f(x...)` as `Ω`, return the tuple:
(Ω, ΔΩ)
The second return value is the differential w.r.t. the output.
The second return value is the tangent w.r.t. the output.
If no method matching `frule((Δf, Δx...), f, x...)` has been defined, then return `nothing`.
Expand Down Expand Up @@ -87,7 +87,7 @@ as `Ω`, return the tuple:
(Ω, (Ω̄₁, Ω̄₂, ...) -> (s̄elf, x̄₁, x̄₂, ...))
Where the second return value is the the propagation rule or pullback.
It takes in differentials corresponding to the outputs (`x̄₁, x̄₂, ...`),
It takes in cotangents corresponding to the outputs (`x̄₁, x̄₂, ...`),
and `s̄elf`, the internal values of the function itself (for closures)
If no method matching `rrule(f, xs...)` has been defined, then return `nothing`.
Expand Down
6 changes: 3 additions & 3 deletions src/tangent_arithmetic.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#==
All differentials need to define + and *.
All tangents need to define + and *.
That happens here.
We just use @eval to define all the combinations for AbstractTangent
Expand Down Expand Up @@ -148,8 +148,8 @@ Base.:+(a::Tangent{P}, b::P) where {P} = b + a
Base.:-(tangent::Tangent{P}) where {P} = map(-, tangent)

# We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful
# In general one doesn't have to represent multiplications of 2 differentials
# Only of a differential and a scaling factor (generally `Real`)
# In general one doesn't have to represent multiplications of 2 tangents
# Only of a tangent and a scaling factor (generally `Real`)
for T in (:Number,)
@eval Base.:*(s::$T, tangent::Tangent) = map(x -> s * x, tangent)
@eval Base.:*(tangent::Tangent, s::$T) = map(x -> x * s, tangent)
Expand Down
14 changes: 7 additions & 7 deletions src/tangent_types/abstract_tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,24 @@ The subtypes of `AbstractTangent` define a custom \"algebra\" for chain
rule evaluation that attempts to factor various features like complex derivative
support, broadcast fusion, zero-elision, etc. into nicely separated parts.
In general a differential type is the type of a derivative of a value.
In general a tangent type is the type of a derivative of a value.
The type of the value is for contrast called the primal type.
Differential types correspond to primal types, although the relation is not one-to-one.
Subtypes of `AbstractTangent` are not the only differential types.
Subtypes of `AbstractTangent` are not the only tangent types.
In fact for the most common primal types, such as `Real` or `AbstractArray{Real}` the
the differential type is the same as the primal type.
the tangent type is the same as the primal type.
In a circular definition: the most important property of a differential is that it should
be able to be added (by defining `+`) to another differential of the same primal type.
In a circular definition: the most important property of a tangent is that it should
be able to be added (by defining `+`) to another tangent of the same primal type.
That allows for gradients to be accumulated.
It generally also should be able to be added to a primal to give back another primal, as
this facilitates gradient descent.
All subtypes of `AbstractTangent` implement the following operations:
- `+(a, b)`: linearly combine differential `a` and differential `b`
- `*(a, b)`: multiply the differential `b` by the scaling factor `a`
- `+(a, b)`: linearly combine tangent `a` and tangent `b`
- `*(a, b)`: multiply the tangent `b` by the scaling factor `a`
- `Base.zero(x) = ZeroTangent()`: a zero.
Further, they often implement other linear operators, such as `conj`, `adjoint`, `dot`.
Expand Down
10 changes: 5 additions & 5 deletions src/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
AbstractZero <: AbstractTangent
Supertype for zero-like differentials—i.e., differentials that act like zero when
Supertype for zero-like tangents—i.e., tangents that act like zero when
added or multiplied to other values.
If an AD system encounters a propagator that takes as input only subtypes of `AbstractZero`,
then it can stop performing AD operations.
Expand Down Expand Up @@ -39,7 +39,7 @@ Base.reshape(z::AbstractZero, size...) = z
"""
ZeroTangent() <: AbstractZero
The additive identity for differentials.
The additive identity for tangents.
This is basically the same as `0`.
A derivative of `ZeroTangent()` does not propagate through the primal function.
"""
Expand All @@ -53,15 +53,15 @@ Base.zero(::Type{<:AbstractTangent}) = ZeroTangent()
"""
NoTangent() <: AbstractZero
This differential indicates that the derivative does not exist.
It is the differential for primal types that are not differentiable,
This tangent indicates that the derivative does not exist.
It is the tangent type for primal types that are not differentiable,
such as integers or booleans (when they are not being used to represent
floating-point values).
The only valid way to perturb such values is to not change them at all.
As a consequence, `NoTangent` is functionally identical to `ZeroTangent()`,
but it provides additional semantic information.
Adding this differential to a primal is generally wrong: gradient-based
Adding `NoTangent()` to a primal is generally wrong: gradient-based
methods cannot be used to optimize over discrete variables.
An optimization package making use of this might want to check for such a case.
Expand Down
18 changes: 9 additions & 9 deletions src/tangent_types/notimplemented.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
"""
@not_implemented(info)
Create a differential that indicates that the derivative is not implemented.
Create a tangent that indicates that the derivative is not implemented.
The `info` should be useful information about the missing differential for debugging.
The `info` should be useful information about the missing tangent for debugging.
!!! note
This macro should be used only if the automatic differentiation would error
otherwise. It is mostly useful if the function has multiple inputs or outputs,
and one has worked out analytically and implemented some but not all differentials.
and one has worked out analytically and implemented some but not all tangents.
!!! note
It is good practice to include a link to a GitHub issue about the missing
differential in the debugging information.
tangent in the debugging information.
"""
macro not_implemented(info)
return :(NotImplemented($__module__, $(QuoteNode(__source__)), $(esc(info))))
Expand All @@ -21,7 +21,7 @@ end
"""
NotImplemented
This differential indicates that the derivative is not implemented.
This tangent indicates that the derivative is not implemented.
It is generally best to construct this using the [`@not_implemented`](@ref) macro,
which will automatically insert the source module and file location.
Expand All @@ -34,11 +34,11 @@ end

# required for `@scalar_rule`
# (together with `conj(x::AbstractTangent) = x` and the definitions in
# differential_arithmetic.jl)
# tangent_arithmetic.jl)
Base.Broadcast.broadcastable(x::NotImplemented) = Ref(x)

# throw error with debugging information for other standard information
# (`+`, `-`, `*`, and `dot` are defined in differential_arithmetic.jl)
# (`+`, `-`, `*`, and `dot` are defined in tangent_arithmetic.jl)
Base.:/(x::NotImplemented, ::Any) = throw(NotImplementedException(x))
Base.:/(::Any, x::NotImplemented) = throw(NotImplementedException(x))
Base.:/(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x))
Expand All @@ -48,7 +48,7 @@ function Base.zero(::Type{<:NotImplemented})
return throw(
NotImplementedException(
@not_implemented(
"`zero` is not defined for missing differentials of type `NotImplemented`"
"`zero` is not defined for missing tangents of type `NotImplemented`"
)
),
)
Expand Down Expand Up @@ -77,7 +77,7 @@ function NotImplementedException(x::NotImplemented)
end

function Base.showerror(io::IO, e::NotImplementedException)
print(io, "differential not implemented @ ", e.mod, " ", e.source)
print(io, "tangent not implemented @ ", e.mod, " ", e.source)
if e.info !== nothing
print(io, "\nInfo: ", e.info)
end
Expand Down
Loading

0 comments on commit f9304b3

Please sign in to comment.