From f9304b32adbd911a7fb6c428cddcdaacd4d3a6ed Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 10 Nov 2021 13:56:27 +0000 Subject: [PATCH] Replace the uses of "differential" with "tangent" where appropriate (#513) --- docs/make.jl | 4 +- docs/src/api.md | 2 +- ...many_differentials.md => many_tangents.md} | 104 +++++++++--------- .../src/rule_author/converting_zygoterules.md | 2 +- docs/src/rule_author/intro.md | 2 +- .../superpowers/gradient_accumulation.md | 10 +- .../{differentials.md => tangents.md} | 8 +- docs/src/rule_author/writing_good_rules.md | 18 +-- src/ChainRulesCore.jl | 4 +- src/accumulation.jl | 2 +- src/projection.jl | 4 +- src/rules.jl | 4 +- src/tangent_arithmetic.jl | 6 +- src/tangent_types/abstract_tangent.jl | 14 +-- src/tangent_types/abstract_zero.jl | 10 +- src/tangent_types/notimplemented.jl | 18 +-- src/tangent_types/tangent.jl | 18 ++- src/tangent_types/thunks.jl | 2 +- 18 files changed, 115 insertions(+), 117 deletions(-) rename docs/src/design/{many_differentials.md => many_tangents.md} (50%) rename docs/src/rule_author/{differentials.md => tangents.md} (70%) diff --git a/docs/make.jl b/docs/make.jl index fea10dc67..9a92cf6fb 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -50,7 +50,7 @@ makedocs(; "Introduction" => "index.md", "How to use ChainRules as a rule author" => [ "Introduction" => "rule_author/intro.md", - "Tangent types" => "rule_author/differentials.md", + "Tangent types" => "rule_author/tangents.md", #"`frule` and `rrule`" => "rule_author/rules.md", # TODO: a complete example "Writing good rules" => "rule_author/writing_good_rules.md", "Testing your rules" => "rule_author/testing.md", @@ -77,7 +77,7 @@ makedocs(; ], "Design" => [ "Changing the Primal" => "design/changing_the_primal.md", - "Many Tangent Types" => "design/many_differentials.md", + "Many Tangent Types" => "design/many_tangents.md", ], "Videos" => "videos.md", "FAQ" => "FAQ.md", diff --git a/docs/src/api.md b/docs/src/api.md index 9b998c0af..5648058e0 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -14,7 +14,7 @@ Pages = ["rule_definition_tools.jl"] Private = false ``` -## Differentials +## Tangent Types ```@autodocs Modules = [ChainRulesCore] Pages = [ diff --git a/docs/src/design/many_differentials.md b/docs/src/design/many_tangents.md similarity index 50% rename from docs/src/design/many_differentials.md rename to docs/src/design/many_tangents.md index 1e0583c3c..62fed1e4d 100644 --- a/docs/src/design/many_differentials.md +++ b/docs/src/design/many_tangents.md @@ -1,47 +1,47 @@ -# [Design Notes: The many-to-many relationship between differential types and primal types](@id manytypes) +# [Design Notes: The many-to-many relationship between tangent types and primal types](@id manytypes) -ChainRules has a system where one primal type (the type having its derivative taken) can have multiple possible differential types (the type of the derivative); and where one differential type can correspond to multiple primal types. -This is in-contrast to the Swift AD efforts, which has one differential type per primal type (Swift uses the term associated tangent type, rather than differential type). +ChainRules has a system where one primal type (the type having its derivative taken) can have multiple possible tangent types (the type of the derivative); and where one tangent type can correspond to multiple primal types. +This is in-contrast to the Swift AD efforts, which has one tangent type per primal type (Swift uses the term associated tangent type). -!!! terminology "differential and associated tangent type" - The use of “associated tangent type” in AD is not technically correct, as differentials naturally live in the [_cotangent_ plane](https://en.wikipedia.org/wiki/Cotangent_space) instead of the [tangent plane](https://en.wikipedia.org/wiki/Tangent_space). +!!! terminology "tangent and associated tangent type" + The use of “associated tangent type” in AD is not technically correct, as they live in the [_cotangent_ plane](https://en.wikipedia.org/wiki/Cotangent_space) instead of the [tangent plane](https://en.wikipedia.org/wiki/Tangent_space). However it is often reasonable for AD to treat the cotangent plane and tangent plane as the same thing, and this was an intentional choice by the Swift team. - Here we will just stick to the ChainRules terminology and only say “differential type” instead of “tangent type”. + In ChainRules we use the term “tangent type” to refer to both tangents and cotangents. -One thing to understand about differentials is that they have to form a [vector space](https://en.wikipedia.org/wiki/Vector_space) (or something very like them). +One thing to understand about tangents is that they have to form a [vector space](https://en.wikipedia.org/wiki/Vector_space) (or something very like them). They need to support addition to each other, they need a zero which doesn't change what it is added to, and they need to support scalar multiplication (this isn't really required, but it is handy for things like gradient descent). -Beyond being a vector space, differentials need to be able to be added to a primal value to get back another primal value. -Or roughly equivalently a differential is a difference between two primal values. +Beyond being a vector space, tangents need to be able to be added to a primal value to get back another primal value. +Or roughly equivalently a tangent is a difference between two primal values. One thing to note in this example is that the primal does not have to be a vector. -As an example, consider `DateTime`. A `DateTime` is not a vector space: there is no origin point, and `DateTime`s cannot be added to each other. The corresponding differential type is any subtype of `Period`, such as `Millisecond`, `Hour`, `Day` etc. +As an example, consider `DateTime`. A `DateTime` is not a vector space: there is no origin point, and `DateTime`s cannot be added to each other. The corresponding tangent type is any subtype of `Period`, such as `Millisecond`, `Hour`, `Day` etc. -## Natural differential +## Natural tangent -For a given primal type, we say a natural differential type is one which people would intuitively think of as representing the difference between two primal values. +For a given primal type, we say a natural tangent type is one which people would intuitively think of as representing the difference between two primal values. It tends to already exist outside of the context of AD. -So `Millisecond`, `Hour`, `Day` etc. are examples of _natural differentials_ for the `DateTime` primal. +So `Millisecond`, `Hour`, `Day` etc. are examples of _natural tangents_ for the `DateTime` primal. -Note here that we already have a one primal type to many differential types relationship. -We have `Millisecond` and `Hour` and `Day` all being valid differential types for `DateTime`. -In this case we _could_ convert them all to a single differential type, such as `Nanoseconds`, but that is not always a reasonable decision: we may run in to overflow, or lots of allocations if we need to use a `BigInt` to represent the number of `Nanosecond` since the start of the universe. +Note here that we already have a one primal type to many tangent types relationship. +We have `Millisecond` and `Hour` and `Day` all being valid tangent types for `DateTime`. +In this case we _could_ convert them all to a single tangent type, such as `Nanoseconds`, but that is not always a reasonable decision: we may run in to overflow, or lots of allocations if we need to use a `BigInt` to represent the number of `Nanosecond` since the start of the universe. For types with more complex semantics, such as array types, these considerations are much more important. -Natural differential types are the types people tend to think in, and thus the type they tend to write custom sensitivity rules in. -An important special case of natural differentials is when the primal type is a vector space (e.g. `Real`,`AbstractMatrix`) in which case it is _common_ for the natural differential type to be the same as the primal type. +Natural tangent types are the types people tend to think in, and thus the type they tend to write custom sensitivity rules in. +An important special case of natural tangents is when the primal type is a vector space (e.g. `Real`,`AbstractMatrix`) in which case it is _common_ for the natural tangent type to be the same as the primal type. One exception to this is `getindex`. -The ideal choice of differential type for `getindex` on a dense array would be some type of sparse array, due to the fact the derivative will have only one non-zero element. -This actually further brings us to a weirdness of differential types not actually being closed under addition, as it would be ideal for the sparse array to become a dense array if summed over all elements. +The ideal choice of tangent type for `getindex` on a dense array would be some type of sparse array, due to the fact the derivative will have only one non-zero element. +This actually further brings us to a weirdness of tangent types not actually being closed under addition, as it would be ideal for the sparse array to become a dense array if summed over all elements. -## Structural differential types +## Structural tangent types -AD cannot automatically determine natural differential types for a primal. For some types we may be able to declare manually their natural differential type. -Other types will not have natural differential types at all - e.g. `NamedTuple`, `Tuple`, `WebServer`, `Flux.Dense` - so we are destined to make some up. -So beyond _natural_ differential types, we also have _structural_ differential types. -ChainRules uses [`Tangent{P, <:NamedTuple}`](@ref Tangent) to represent a structural differential type corresponding to primal type `P`. +AD cannot automatically determine the natural tangent types for a primal. For some types we may be able to declare manually their natural tangent type. +Other types will not have natural tangent types at all - e.g. `NamedTuple`, `Tuple`, `WebServer`, `Flux.Dense` - so we are destined to make some up. +So beyond _natural_ tangent types, we also have _structural_ tangent types. +ChainRules uses [`Tangent{P, <:NamedTuple}`](@ref Tangent) to represent a structural tangent type corresponding to primal type `P`. [Zygote](https://github.com/FluxML/Zygote.jl/) v0.4 uses `NamedTuple`. -Structural differentials are derived from the structure of the input. +Structural tangents are derived from the structure of the input. Either automatically, as part of the AD, or manually, as part of a custom rule. Consider the structure of `DateTime`: @@ -53,7 +53,7 @@ DateTime value: Int64 63719890305605 ``` -The corresponding structural differential is: +The corresponding structural tangent is: ```julia Tangent{DateTime}( instant::Tangent{UTInstant{Millisecond}}( @@ -69,17 +69,17 @@ Tangent{DateTime}( In Swift `Int` is considered non-differentiable, which is quite reasonable; it doesn’t have a very good definition of the limit of a small step (as that would be some floating/fixed point type). `Int` is intrinsically discrete. It is commonly used for indexing, and if one takes a gradient step, say turning `x[2]` into `x[2.1]` then that is an error. - However, disallowing `Int` to be used as a differential means we cannot handle cases like `DateTime` having an inner field of milliseconds counted as an integer from the unix epoch or other cases where an integer is used as a convenience for computational efficiency. + However, disallowing `Int` to be used as a tangent means we cannot handle cases like `DateTime` having an inner field of milliseconds counted as an integer from the unix epoch or other cases where an integer is used as a convenience for computational efficiency. In the case where a custom sensitivity rule claims that there is a non-zero derivative for an `Int` argument that is being used for indexing, that code is simply wrong. We can’t handle incorrect code and trying to is a path toward madness. Julia, unlike Swift, is not well suited to handling rules about what you can and can’t do with particular types. -So the structural differential is another type of differential. -We must support both natural and structural differentials because AD can only create structural differentials (unless using custom sensitivity rules) and all custom sensitivities are only written in terms of natural differentials, as that is what is used in papers about derivatives. +So the structural tangent is another type of tangent. +We must support both natural and structural tangents because AD can only create structural tangents (unless using custom sensitivity rules) and all custom sensitivities are only written in terms of natural tangents, as that is what is used in papers about derivatives. -## Semi-structural differentials +## Semi-structural tangents -Where there is no natural differential type for the outermost type but there is for some of its fields, we call this a "semi-structural" differential. +Where there is no natural tangent type for the outermost type but there is for some of its fields, we call this a "semi-structural" tangent. Consider if we had a representation of a country's GDP as output by some continuous time model like a Gaussian Process, where that representation is as a sequence of `TimeSample`s structured as follows: @@ -101,7 +101,7 @@ TimeSample value: Float64 2.6e9 ``` -Thus we see the that structural differential would be: +Thus we see the that structural tangent would be: ```julia Tangent{TimeSample}( time::Tangent{DateTime}( @@ -115,8 +115,8 @@ Tangent{TimeSample}( ) ``` -But instead in the custom sensitivity rule we would write a semi-structured differential type. -Since there is not a natural differential type for `TimeSample` but there is for `DateTime`. +But instead in the custom sensitivity rule we would write a semi-structured tangent type. +Since there is not a natural tangent type for `TimeSample` but there is for `DateTime`. ```julia Tangent{TimeSample}( time::Day, @@ -124,12 +124,12 @@ Tangent{TimeSample}( ) ``` -So the rule author has written a structural differential with some fields that are natural differentials. +So the rule author has written a structural tangent with some fields that are natural tangents. Another related case is for types that overload `getproperty` such as `SVD` and `QR`. -In this case the structural differential will be based on the fields, but those fields do not always have an easy relation to what is actually used in math. +In this case the structural tangent will be based on the fields, but those fields do not always have an easy relation to what is actually used in math. For example, the `QR` type has fields `factors` and `t`, but we would more naturally think in terms of the properties `Q` and `R`. -So most rule authors would want to write semi-structural differentials based on the properties. +So most rule authors would want to write semi-structural tangents based on the properties. To return to the question of why ChainRules has `Tangent{P, <:NamedTuple}` whereas Zygote v0.4 just has `NamedTuple`, it relates to semi-structural derivatives, and being able to overload things more generally. If one knows that one has a semi-structural derivative based on property names, like `Tangent{QR}(Q=..., R=...)`, and one is adding it to the true structural derivative based on field names `Tangent{QR}(factors=..., τ=...)`, then we need to overload the addition operator to perform that correctly. @@ -139,45 +139,45 @@ In fact we can't actually overload addition at all for `NamedTuple` as that woul Another use of the primal being a type parameter is to catch errors. ChainRules disallows the addition of `Tangent{SVD}` to `Tangent{QR}` since in a correctly differentiated program that can never occur. -## Differentials types for computational efficiency +## Tangent types for computational efficiency -There is another kind of unnatural differential. +There is another kind of unnatural tangent. One that is for computational efficiency. ChainRules has [`Thunk`](@ref)s and [`InplaceableThunk`](@ref)s, which wrap the computation of a derivative and delays that work until it is needed, either via the derivative being added to something or being [`unthunk`](@ref)ed manually, thus saving time if it is never used. -Another differential type used for efficiency is [`ZeroTangent`](@ref) which represents the hard zero (in Zygote v0.4 this is `nothing`). +Another tangent type used for efficiency is [`ZeroTangent`](@ref) which represents the hard zero (in Zygote v0.4 this is `nothing`). For example the derivative of `f(x, y)=2x` with respect to `y` is `ZeroTangent()`. Add `ZeroTangent()` to anything, and one gets back the original thing without change. -We noted that all differentials need to be a vector space. +We noted that all tangents need to be a vector space. `ZeroTangent()` is the [trivial vector space](https://proofwiki.org/wiki/Definition:Trivial_Vector_Space). Further, add `ZeroTangent()` to any primal value (no matter the type) and you get back another value of the same primal type (the same value in fact). -So it meets the requirements of a differential type for *all* primal types. +So it meets the requirements of a tangent type for *all* primal types. `ZeroTangent` can save on memory (since we can avoid allocating anything) and on time (since performing the multiplication -`ZeroTangent` and `Thunk` are both examples of a differential type that is valid for multiple primal types. +`ZeroTangent` and `Thunk` are both examples of a tangent type that is valid for multiple primal types. ## Conclusion -Now, you have seen examples of both differential types that work for multiple primal types, and primal types that have multiple valid differential types. +Now, you have seen examples of both tangent types that work for multiple primal types, and primal types that have multiple valid tangent types. Semantically we can handle these very easily in julia. Just put in a few more dispatching on `+`. Multiple-dispatch is great like that. The down-side is our type-inference becomes hard. -If you have exactly 1 differential type for each primal type, you can very easily workout what all the types on your reverse pass will be - you don't really need type inference - but you lose so much expressibility. +If you have exactly 1 tangent type for each primal type, you can very easily workout what all the types on your reverse pass will be - you don't really need type inference - but you lose so much expressibility. ## Appendix: What Swift does I don't know how Swift is handling thunks, maybe they are not, maybe they have an optimizing compiler that can just slice out code-paths that don't lead to values that get used; maybe they have a language built in for lazy computation. -They are, as I understand it, handling `ZeroTangent` by requiring every differential type to define a `zero` method -- which it has since it is a vector space. +They are, as I understand it, handling `ZeroTangent` by requiring every tangent type to define a `zero` method -- which it has since it is a vector space. This costs memory and time, but probably not actually all that much. -With regards to handling multiple different differential types for one primal, like natural and structural derivatives, everything needs to be converted to the _canonical_ differential type of that primal. +With regards to handling multiple different tangent types for one primal, like natural and structural derivatives, everything needs to be converted to the _canonical_ tangent type of that primal. -As I understand it, things can be automatically converted by defining conversion protocols or something like that, so rule authors can return anything that has a conversion protocol to the canonical differential type of the primal. +As I understand it, things can be automatically converted by defining conversion protocols or something like that, so rule authors can return anything that has a conversion protocol to the canonical tangent type of the primal. However, it seems like this will run into problems. -Recall that the natural differential in the case of `getindex` on an `AbstractArray` was a sparse array. -But for say the standard dense `Array`, the only reasonable canonical differential type is also a dense `Array`. +Recall that the natural tangent in the case of `getindex` on an `AbstractArray` was a sparse array. +But for say the standard dense `Array`, the only reasonable canonical tangent type is also a dense `Array`. But if you convert a sparse array into a dense array you do giant allocations to fill in all the other entries with zero. -So this is the story about why we have many-to-many differential types in ChainRules. +So this is the story about why we have many-to-many tangent types in ChainRules. diff --git a/docs/src/rule_author/converting_zygoterules.md b/docs/src/rule_author/converting_zygoterules.md index 81750adc3..d0272a13a 100644 --- a/docs/src/rule_author/converting_zygoterules.md +++ b/docs/src/rule_author/converting_zygoterules.md @@ -44,7 +44,7 @@ See docs on [Constructors](@ref). ## Include the derivative with respect to the function object itself The `ZygoteRules.@adjoint` macro automagically[^1] inserts an extra `nothing` in the return for the function it generates to represent the derivative of output with respect to the function object. ChainRules as a philosophy avoids magic as much as possible, and thus require you to return it explicitly. -If it is a plain function (like `typeof(sin)`), then the differential will be [`NoTangent`](@ref). +If it is a plain function (like `typeof(sin)`), then the tangent will be [`NoTangent`](@ref). [^1]: unless you write it in functor form (i.e. `@adjoint (f::MyType)(args...)=...`), in that case like for `rrule` you need to include it explictly. diff --git a/docs/src/rule_author/intro.md b/docs/src/rule_author/intro.md index 44ec3eb12..a3b559d9d 100644 --- a/docs/src/rule_author/intro.md +++ b/docs/src/rule_author/intro.md @@ -2,7 +2,7 @@ This section of the docs will tell you everything you need to know about writing rules for your package. -It will help with understanding differential types, the anatomy of the `frule` and +It will help with understanding tangent types, the anatomy of the `frule` and the `rrule`, and provide tips on writing good rules, as well as how to test them easily using finite differences. diff --git a/docs/src/rule_author/superpowers/gradient_accumulation.md b/docs/src/rule_author/superpowers/gradient_accumulation.md index cfae113ba..68d4cc4aa 100644 --- a/docs/src/rule_author/superpowers/gradient_accumulation.md +++ b/docs/src/rule_author/superpowers/gradient_accumulation.md @@ -19,7 +19,7 @@ end The AD software must transform that into something which repeatedly sums up the gradient of each part: `X̄ = ā + b̄`. -This requires that all differential types `D` must implement `+`: `+(::D, ::D)::D`. +This requires that all tangent types `D` must implement `+`: `+(::D, ::D)::D`. We can note that in this particular case `ā` and `b̄` will both be arrays. This operation (`X̄ = ā + b̄`) will allocate one array to hold `ā`, another one to hold `b̄`, and a third one to hold `ā + b̄`. @@ -47,7 +47,7 @@ AD systems can generate `add!!` instead of `+` when accumulating gradient to tak ### Inplaceable Thunks (`InplaceableThunks`) avoid allocating values in the first place. We got down to two allocations from using [`add!!`](@ref), but can we do better? -We can think of having a differential type which acts on a partially accumulated result, to mutate it to contain its current value plus the partial derivative being accumulated. +We can think of having a tangent type which acts on a partially accumulated result, to mutate it to contain its current value plus the partial derivative being accumulated. Rather than having an actual computed value, we can just have a thing that will act on a value to perform the addition. Let's illustrate it with our example. @@ -79,9 +79,9 @@ The `val` field use a plain [`Thunk`](@ref) to avoid the computation (and thus a !!! note "Do we need both representations?" Right now every [`InplaceableThunk`](@ref) has two fields that need to be specified. The value form (represented as a the [`Thunk`](@ref) typed field), and the action form (represented as the `add!` field). - It is possible in a future version of ChainRulesCore.jl we will work out a clever way to find the zero differential for arbitrary primal values. - Given that, we could always just determine the value form from `inplaceable.add!(zero_differential(primal))`. - There are some technical difficulties in finding the zero differentials, but this may be solved at some point. + It is possible in a future version of ChainRulesCore.jl we will work out a clever way to find the zero tangent for arbitrary primal values. + Given that, we could always just determine the value form from `inplaceable.add!(zero_tangent(primal))`. + There are some technical difficulties in finding the zero tangents, but this may be solved at some point. The `+` operation on `InplaceableThunk`s is overloaded to [`unthunk`](@ref) that `val` field to get the value form. diff --git a/docs/src/rule_author/differentials.md b/docs/src/rule_author/tangents.md similarity index 70% rename from docs/src/rule_author/differentials.md rename to docs/src/rule_author/tangents.md index c2b4c9d7c..0cac01688 100644 --- a/docs/src/rule_author/differentials.md +++ b/docs/src/rule_author/tangents.md @@ -1,11 +1,11 @@ # [Tangent types](@id tangents) The values that come back from pullbacks or pushforwards are not always the same type as the input/outputs of the primal function. -They are differentials, which correspond roughly to something able to represent the difference between two values of the primal types. -A differential might be such a regular type, like a `Number`, or a `Matrix`, matching to the original type; +They are tangents, which correspond roughly to something able to represent the difference between two values of the primal types. +A tangent might be such a regular type, like a `Number`, or a `Matrix`, matching to the original type; or it might be one of the [`AbstractTangent`](@ref ChainRulesCore.AbstractTangent) subtypes. -Differentials support a number of operations. +Tangents support a number of operations. Most importantly: `+` and `*`, which let them act as mathematical objects. The most important `AbstractTangent`s when getting started are the ones about avoiding work: @@ -14,6 +14,6 @@ The most important `AbstractTangent`s when getting started are the ones about av - [`ZeroTangent`](@ref): It is a special representation of `0`. It does great things around avoiding expanding `Thunks` in addition. ### Other `AbstractTangent`s: - - [`Tangent{P}`](@ref Tangent): this is the differential for tuples and structs. Use it like a `Tuple` or `NamedTuple`. The type parameter `P` is for the primal type. + - [`Tangent{P}`](@ref Tangent): this is the tangent for tuples and structs. Use it like a `Tuple` or `NamedTuple`. The type parameter `P` is for the primal type. - [`NoTangent`](@ref): Zero-like, represents that the operation on this input is not differentiable. Its primal type is normally `Integer` or `Bool`. - [`InplaceableThunk`](@ref): it is like a `Thunk` but it can do in-place `add!`. diff --git a/docs/src/rule_author/writing_good_rules.md b/docs/src/rule_author/writing_good_rules.md index dab02a00c..2cfe91cf0 100644 --- a/docs/src/rule_author/writing_good_rules.md +++ b/docs/src/rule_author/writing_good_rules.md @@ -44,7 +44,7 @@ This woull be solved once [JuliaLang/julia#38241](https://github.com/JuliaLang/j ## Use `Thunk`s appropriately -If work is only required for one of the returned differentials, then it should be wrapped in a `@thunk` (potentially using a `begin`-`end` block). +If work is only required for one of the returned tangents, then it should be wrapped in a `@thunk` (potentially using a `begin`-`end` block). If there are multiple return values, their computation should almost always be wrapped in a `@thunk`. @@ -169,16 +169,16 @@ For example, if a primal type `P` overloads subtraction (`-(::P,::P)`) then that Common cases for types that represent a [vector-space](https://en.wikipedia.org/wiki/Vector_space) (e.g. `Float64`, `Array{Float64}`) is that the natural tangent type is the same as the primal type. However, this is not always the case. For example for a [`PDiagMat`](https://github.com/JuliaStats/PDMats.jl) a natural tangent is `Diagonal` since there is no requirement that a positive definite diagonal matrix has a positive definite tangent. -Another example is for a `DateTime`, any `Period` subtype, such as `Millisecond` or `Nanosecond` is a natural differential. +Another example is for a `DateTime`, any `Period` subtype, such as `Millisecond` or `Nanosecond` is a natural tangent. There are often many different natural tangent types for a given primal type. However, they are generally closely related and duck-type the same. For example, for most `AbstractArray` subtypes, most other `AbstractArray`s (of right size and element type) can be considered as natural tangent types. Not all types have natural tangent types. -For example there is no natural differential for a `Tuple`. +For example there is no natural tangent for a `Tuple`. It is not a `Tuple` since that doesn't have any method for `+`. Similar is true for many `struct`s. -For those cases there is only a structural differential. +For those cases there is only a structural tangent. ### Structural tangents @@ -216,10 +216,10 @@ In this sense they wrap either a natural or structural tangent. ## Use `@not_implemented` appropriately -You can use [`@not_implemented`](@ref) to mark missing differentials. -This is helpful if the function has multiple inputs or outputs, and you have worked out analytically and implemented some but not all differentials. +You can use [`@not_implemented`](@ref) to mark missing tangents. +This is helpful if the function has multiple inputs or outputs, and you have worked out analytically and implemented some but not all tangents. -It is recommended to include a link to a GitHub issue about the missing differential in the debugging information: +It is recommended to include a link to a GitHub issue about the missing tangent in the debugging information: ```julia @not_implemented( """ @@ -229,9 +229,9 @@ It is recommended to include a link to a GitHub issue about the missing differen ) ``` -Do not use `@not_implemented` if the differential does not exist mathematically (use `NoTangent()` instead). +Do not use `@not_implemented` if the tangent does not exist mathematically (use `NoTangent()` instead). -Note: [ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl) marks `@not_implemented` differentials as "test broken". +Note: [ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl) marks `@not_implemented` tangents as "test broken". ## Use rule definition tools diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index dbc0f057a..f9eaf59f6 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -11,10 +11,10 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod export frule_via_ad, rrule_via_ad # definition helper macros export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented -export ProjectTo, canonicalize, unthunk # differential operations +export ProjectTo, canonicalize, unthunk # tangent operations export add!! # gradient accumulation operations export ignore_derivatives, @ignore_derivatives -# differentials +# tangents export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk include("compat.jl") diff --git a/src/accumulation.jl b/src/accumulation.jl index c9a38956a..6e186546e 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -39,7 +39,7 @@ end Returns true if `x` is suitable for for storing inplace accumulation of gradients. For arrays this boils down `x .= y` if will work to mutate `x`, if `y` is an appropriate -differential. +tangent. Wrapper array types do not need to overload this if they overload `Base.parent`, and are `is_inplaceable_destination` if and only if their parent array is. Other types should overload this, as it defaults to `false`. diff --git a/src/projection.jl b/src/projection.jl index 2986f388f..78a9b389b 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -1,7 +1,7 @@ """ (p::ProjectTo{T})(dx) -Projects the differential `dx` onto a specific tangent space. +Projects the tangent `dx` onto a specific tangent space. The type `T` is meant to encode the largest acceptable space, so usually this enforces `p(dx)::T`. But some subspaces which aren't subtypes of `T` may @@ -80,7 +80,7 @@ _maybe_call(f, x) = f """ ProjectTo(x) -Returns a `ProjectTo{T}` functor which projects a differential `dx` onto the +Returns a `ProjectTo{T}` functor which projects a tangent `dx` onto the relevant tangent space for `x`. Custom `ProjectTo` methods are provided for many subtypes of `Number` (to e.g. ensure precision), diff --git a/src/rules.jl b/src/rules.jl index d15655b24..d99e54a01 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -5,7 +5,7 @@ Expressing the output of `f(x...)` as `Ω`, return the tuple: (Ω, ΔΩ) -The second return value is the differential w.r.t. the output. +The second return value is the tangent w.r.t. the output. If no method matching `frule((Δf, Δx...), f, x...)` has been defined, then return `nothing`. @@ -87,7 +87,7 @@ as `Ω`, return the tuple: (Ω, (Ω̄₁, Ω̄₂, ...) -> (s̄elf, x̄₁, x̄₂, ...)) Where the second return value is the the propagation rule or pullback. -It takes in differentials corresponding to the outputs (`x̄₁, x̄₂, ...`), +It takes in cotangents corresponding to the outputs (`x̄₁, x̄₂, ...`), and `s̄elf`, the internal values of the function itself (for closures) If no method matching `rrule(f, xs...)` has been defined, then return `nothing`. diff --git a/src/tangent_arithmetic.jl b/src/tangent_arithmetic.jl index a6609ac33..439f0ac8f 100644 --- a/src/tangent_arithmetic.jl +++ b/src/tangent_arithmetic.jl @@ -1,5 +1,5 @@ #== -All differentials need to define + and *. +All tangents need to define + and *. That happens here. We just use @eval to define all the combinations for AbstractTangent @@ -148,8 +148,8 @@ Base.:+(a::Tangent{P}, b::P) where {P} = b + a Base.:-(tangent::Tangent{P}) where {P} = map(-, tangent) # We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful -# In general one doesn't have to represent multiplications of 2 differentials -# Only of a differential and a scaling factor (generally `Real`) +# In general one doesn't have to represent multiplications of 2 tangents +# Only of a tangent and a scaling factor (generally `Real`) for T in (:Number,) @eval Base.:*(s::$T, tangent::Tangent) = map(x -> s * x, tangent) @eval Base.:*(tangent::Tangent, s::$T) = map(x -> x * s, tangent) diff --git a/src/tangent_types/abstract_tangent.jl b/src/tangent_types/abstract_tangent.jl index 029d03af9..2f5d0a492 100644 --- a/src/tangent_types/abstract_tangent.jl +++ b/src/tangent_types/abstract_tangent.jl @@ -7,15 +7,15 @@ The subtypes of `AbstractTangent` define a custom \"algebra\" for chain rule evaluation that attempts to factor various features like complex derivative support, broadcast fusion, zero-elision, etc. into nicely separated parts. -In general a differential type is the type of a derivative of a value. +In general a tangent type is the type of a derivative of a value. The type of the value is for contrast called the primal type. Differential types correspond to primal types, although the relation is not one-to-one. -Subtypes of `AbstractTangent` are not the only differential types. +Subtypes of `AbstractTangent` are not the only tangent types. In fact for the most common primal types, such as `Real` or `AbstractArray{Real}` the -the differential type is the same as the primal type. +the tangent type is the same as the primal type. -In a circular definition: the most important property of a differential is that it should -be able to be added (by defining `+`) to another differential of the same primal type. +In a circular definition: the most important property of a tangent is that it should +be able to be added (by defining `+`) to another tangent of the same primal type. That allows for gradients to be accumulated. It generally also should be able to be added to a primal to give back another primal, as @@ -23,8 +23,8 @@ this facilitates gradient descent. All subtypes of `AbstractTangent` implement the following operations: - - `+(a, b)`: linearly combine differential `a` and differential `b` - - `*(a, b)`: multiply the differential `b` by the scaling factor `a` + - `+(a, b)`: linearly combine tangent `a` and tangent `b` + - `*(a, b)`: multiply the tangent `b` by the scaling factor `a` - `Base.zero(x) = ZeroTangent()`: a zero. Further, they often implement other linear operators, such as `conj`, `adjoint`, `dot`. diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 50a5efbd2..bc1cbd161 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -1,7 +1,7 @@ """ AbstractZero <: AbstractTangent -Supertype for zero-like differentials—i.e., differentials that act like zero when +Supertype for zero-like tangents—i.e., tangents that act like zero when added or multiplied to other values. If an AD system encounters a propagator that takes as input only subtypes of `AbstractZero`, then it can stop performing AD operations. @@ -39,7 +39,7 @@ Base.reshape(z::AbstractZero, size...) = z """ ZeroTangent() <: AbstractZero -The additive identity for differentials. +The additive identity for tangents. This is basically the same as `0`. A derivative of `ZeroTangent()` does not propagate through the primal function. """ @@ -53,15 +53,15 @@ Base.zero(::Type{<:AbstractTangent}) = ZeroTangent() """ NoTangent() <: AbstractZero -This differential indicates that the derivative does not exist. -It is the differential for primal types that are not differentiable, +This tangent indicates that the derivative does not exist. +It is the tangent type for primal types that are not differentiable, such as integers or booleans (when they are not being used to represent floating-point values). The only valid way to perturb such values is to not change them at all. As a consequence, `NoTangent` is functionally identical to `ZeroTangent()`, but it provides additional semantic information. -Adding this differential to a primal is generally wrong: gradient-based +Adding `NoTangent()` to a primal is generally wrong: gradient-based methods cannot be used to optimize over discrete variables. An optimization package making use of this might want to check for such a case. diff --git a/src/tangent_types/notimplemented.jl b/src/tangent_types/notimplemented.jl index a6b9cc5f9..661308a11 100644 --- a/src/tangent_types/notimplemented.jl +++ b/src/tangent_types/notimplemented.jl @@ -1,18 +1,18 @@ """ @not_implemented(info) -Create a differential that indicates that the derivative is not implemented. +Create a tangent that indicates that the derivative is not implemented. -The `info` should be useful information about the missing differential for debugging. +The `info` should be useful information about the missing tangent for debugging. !!! note This macro should be used only if the automatic differentiation would error otherwise. It is mostly useful if the function has multiple inputs or outputs, - and one has worked out analytically and implemented some but not all differentials. + and one has worked out analytically and implemented some but not all tangents. !!! note It is good practice to include a link to a GitHub issue about the missing - differential in the debugging information. + tangent in the debugging information. """ macro not_implemented(info) return :(NotImplemented($__module__, $(QuoteNode(__source__)), $(esc(info)))) @@ -21,7 +21,7 @@ end """ NotImplemented -This differential indicates that the derivative is not implemented. +This tangent indicates that the derivative is not implemented. It is generally best to construct this using the [`@not_implemented`](@ref) macro, which will automatically insert the source module and file location. @@ -34,11 +34,11 @@ end # required for `@scalar_rule` # (together with `conj(x::AbstractTangent) = x` and the definitions in -# differential_arithmetic.jl) +# tangent_arithmetic.jl) Base.Broadcast.broadcastable(x::NotImplemented) = Ref(x) # throw error with debugging information for other standard information -# (`+`, `-`, `*`, and `dot` are defined in differential_arithmetic.jl) +# (`+`, `-`, `*`, and `dot` are defined in tangent_arithmetic.jl) Base.:/(x::NotImplemented, ::Any) = throw(NotImplementedException(x)) Base.:/(::Any, x::NotImplemented) = throw(NotImplementedException(x)) Base.:/(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x)) @@ -48,7 +48,7 @@ function Base.zero(::Type{<:NotImplemented}) return throw( NotImplementedException( @not_implemented( - "`zero` is not defined for missing differentials of type `NotImplemented`" + "`zero` is not defined for missing tangents of type `NotImplemented`" ) ), ) @@ -77,7 +77,7 @@ function NotImplementedException(x::NotImplemented) end function Base.showerror(io::IO, e::NotImplementedException) - print(io, "differential not implemented @ ", e.mod, " ", e.source) + print(io, "tangent not implemented @ ", e.mod, " ", e.source) if e.info !== nothing print(io, "\nInfo: ", e.info) end diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index ec7c64448..38c99eefa 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -1,11 +1,11 @@ """ Tangent{P, T} <: AbstractTangent -This type represents the differential for a `struct`/`NamedTuple`, or `Tuple`. -`P` is the the corresponding primal type that this is a differential for. +This type represents the tangent for a `struct`/`NamedTuple`, or `Tuple`. +`P` is the the corresponding primal type that this is a tangent for. `Tangent{P}` should have fields (technically properties), that match to a subset of the -fields of the primal type; and each should be a differential type matching to the primal +fields of the primal type; and each should be a tangent type matching to the primal type of that field. Fields of the P that are not present in the Tangent are treated as `Zero`. @@ -23,7 +23,7 @@ function is provided. """ struct Tangent{P,T} <: AbstractTangent # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict - # (but potentially a different one, as it doesn't contain differentials) + # (but potentially a different one, as it doesn't contain tangents) backing::T function Tangent{P,T}(backing) where {P,T} @@ -310,7 +310,7 @@ elementwise_add(a::Dict, b::Dict) = merge(+, a, b) struct PrimalAdditionFailedException{P} <: Exception primal::P - differential::Tangent{P} + tangent::Tangent{P} original::Exception end @@ -318,15 +318,13 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} println(io, "Could not construct $P after addition.") println(io, "This probably means no default constructor is defined.") println(io, "Either define a default constructor") - printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")"; color=:blue) + printstyled(io, "$P(", join(propertynames(err.tangent), ", "), ")"; color=:blue) println(io, "\nor overload") printstyled( - io, - "ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.differential)))"; - color=:blue, + io, "ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.tangent)))"; color=:blue ) println(io, "\nor overload") - printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color=:blue) + printstyled(io, "Base.:+(::$P, ::$(typeof(err.tangent)))"; color=:blue) println(io, "\nOriginal Exception:") printstyled(io, err.original; color=:yellow) return println(io) diff --git a/src/tangent_types/thunks.jl b/src/tangent_types/thunks.jl index e065bea62..a39c8416a 100644 --- a/src/tangent_types/thunks.jl +++ b/src/tangent_types/thunks.jl @@ -153,7 +153,7 @@ Base.transpose(x::AbstractThunk) = @thunk(transpose(unthunk(x))) """ Thunk(()->v) A thunk is a deferred computation. -It wraps a zero argument closure that when invoked returns a differential. +It wraps a zero argument closure that when invoked returns a tangent. `@thunk(v)` is a macro that expands into `Thunk(()->v)`. To evaluate the wrapped closure, call [`unthunk`](@ref) which is a no-op when the