This repository was archived by the owner on Mar 30, 2022. It is now read-only.

Update differentiation docs to remove CotangentVector. #191

Merged
merged 2 commits into from
Nov 14, 2019
12 changes: 6 additions & 6 deletions docs/DifferentiableFunctions.md
@@ -259,8 +259,8 @@ extension Layer {
/// gradients at the layer and at the input, respectively.
func appliedForBackpropagation(to input: Input)
-> (output: Output,
- backpropagator: (_ direction: Output.CotangentVector)
-   -> (layerGradient: CotangentVector, inputGradient: Input.CotangentVector)) {
+ backpropagator: (_ direction: Output.TangentVector)
+   -> (layerGradient: TangentVector, inputGradient: Input.TangentVector)) {
let (out, pullback) = valueWithPullback(at: input) { layer, input in
return layer(input)
}
@@ -347,13 +347,13 @@ Internally, `differentiableFunction(from:)` is defined just using the
```swift
/// Returns a differentiable function given its derivative.
public func differentiableFunction<T: Differentiable, R: Differentiable>(
- from vjp: @escaping (T) -> (value: R, pullback: (R.CotangentVector) -> T.CotangentVector)
+ from vjp: @escaping (T) -> (value: R, pullback: (R.TangentVector) -> T.TangentVector)
) -> @differentiable (T) -> R {
func original(_ x: T) -> R {
return vjp(x).value
}
@differentiating(original)
- func derivative(_ x: T) -> (value: R, pullback: (R.CotangentVector) -> T.CotangentVector) {
+ func derivative(_ x: T) -> (value: R, pullback: (R.TangentVector) -> T.TangentVector) {
return vjp(x)
}
return original
@@ -431,8 +431,8 @@ defined:
// functions. It simply returns `pullback(1)`.
func gradient<T, R>(
at x: T, in f: @differentiable (T) -> R
- ) -> T.CotangentVector
- where T: Differentiable, R: FloatingPoint & Differentiable, R.CotangentVector == R
+ ) -> T.TangentVector
+ where T: Differentiable, R: FloatingPoint & Differentiable, R.TangentVector == R
{
let (value, pullback) = valueWithPullback(at: x, in: f)
return pullback(R(1))
59 changes: 12 additions & 47 deletions docs/DifferentiableTypes.md
@@ -51,7 +51,7 @@ print(𝛁v)
// Vector(x: 2.0, y: 0.0, z: 0.0)
```

- A `Differentiable`-conforming type may have stored properties that are not meant to have a derivative with respect to `self`. Use the `@noDerivative` attribute to mark those properties; they will not have a corresponding entry in the synthesized `TangentVector`, `CotangentVector`, and `AllDifferentiableVariables` struct types.
+ A `Differentiable`-conforming type may have stored properties that are not meant to have a derivative with respect to `self`. Use the `@noDerivative` attribute to mark those properties; they will not have a corresponding entry in the synthesized `TangentVector` and `AllDifferentiableVariables` struct types.

Here’s an example deep learning layer with some `@noDerivative` properties:

@@ -103,30 +103,19 @@ public protocol Differentiable {
/// The tangent bundle of this differentiable manifold.
associatedtype TangentVector: AdditiveArithmetic & Differentiable
where TangentVector.TangentVector == TangentVector,
- TangentVector.CotangentVector == CotangentVector,
TangentVector.AllDifferentiableVariables == TangentVector

- /// The cotangent bundle of this differentiable manifold.
- associatedtype CotangentVector: AdditiveArithmetic & Differentiable
- where CotangentVector.TangentVector == CotangentVector,
- CotangentVector.CotangentVector == TangentVector,
- CotangentVector.AllDifferentiableVariables == CotangentVector

/// The type of all differentiable variables in this type.
associatedtype AllDifferentiableVariables: Differentiable
where AllDifferentiableVariables.AllDifferentiableVariables == AllDifferentiableVariables,
AllDifferentiableVariables.TangentVector == TangentVector,
- AllDifferentiableVariables.CotangentVector == CotangentVector

/// All differentiable variables in this type.
var allDifferentiableVariables: AllDifferentiableVariables { get }

/// Returns `self` moved along the value space towards the given tangent vector.
/// In Riemannian geometry (mathematics), this represents exponential map.
func moved(along direction: TangentVector) -> Self

- /// Converts a cotangent vector to its corresponding tangent vector.
- func tangentVector(from cotangent: CotangentVector) -> TangentVector
}
```

@@ -141,20 +130,15 @@ Mathematically, `Differentiable` represents a [differentiable manifold]: this is
</p>

Here is a detailed explanation of the `Differentiable` protocol:
- * `associatedtype TangentVector` represents the type of directional derivatives computed via forward-mode differentiation.
- * `associatedtype CotangentVector` represents the type of gradient values computed via reverse-mode differentiation.
- * `CotangentVector` types are used and produced by differential operators like `gradient` and `pullback`.
+ * `associatedtype TangentVector` represents the type of derivatives.
* `var allDifferentiableVariables: AllDifferentiableVariables` represents all differentiable variables in an instance of the conforming type, where `associatedtype AllDifferentiableVariables` is the type of all differentiable variables.
* The motivation/design behind "all differentiable variables" is enabling key-path-based parameter optimization by making parameters and their gradients have the same type. Read the [synthesis rules](#compiler-synthesized-implementations) below and the [parameter optimization document][parameter-optimization] for more information.
- * `TangentVector`, `CotangentVector`, and `AllDifferentiableVariables` are closely related.
+ * `TangentVector` and `AllDifferentiableVariables` are closely related.
* All three associated types must themselves conform to `Differentiable`.
* The `Differentiable` protocol associated types of the associated types themselves are defined to be mathematically correct.
* `Foo.TangentVector.TangentVector` is `Foo.TangentVector` itself.
- * `Foo.CotangentVector.TangentVector` is `Foo.CotangentVector` itself.
- * `Foo.TangentVector.CotangentVector` is `Foo.CotangentVector`.
- * `Foo.CotangentVector.CotangentVector` is `Foo.TangentVector`.
- * `Foo.AllDifferentiableVariables` has the same `TangentVector` and `CotangentVector` as `Foo`.
- * Additionally, `TangentVector` and `CotangentVector` must conform to `AdditiveArithmetic`, so that they can be zero-initialized and accumulated via addition. These are necessary to perform the chain rule of differentiation.
+ * `Foo.AllDifferentiableVariables` has the same `TangentVector` as `Foo`.
+ * Additionally, `TangentVector` must conform to `AdditiveArithmetic`, so that they can be zero-initialized and accumulated via addition. These are necessary to perform the chain rule of differentiation.
* Manifold operations.
* These currently involve `tangentVector(from:)` and `moved(along:)`. These operations can be useful for implementing manifold-related algorithms, like optimization on manifolds, but are not relevant for simple differentiation use cases.

@@ -163,15 +147,13 @@ The standard library defines conformances to the `Differentiable` protocol for `
```swift
extension Float: Differentiable {
public typealias TangentVector = Float
- public typealias CotangentVector = Float
public typealias AllDifferentiableVariables = Float
}
// Conformances for `Double` and `Float80` are defined similarly.

// `Tensor` is defined in the TensorFlow library and represents a multidimensional array.
extension Tensor: Differentiable where Scalar: TensorFlowFloatingPoint {
public typealias TangentVector = Tensor
- public typealias CotangentVector = Tensor
public typealias AllDifferentiableVariables = Tensor
}
```
@@ -190,16 +172,16 @@ The synthesis behavior is explained below.

### Associated type synthesis

- Here are the synthesis rules for the three `Differentiable` associated types: `TangentVector`, `CotangentVector`, and `AllDifferentiableVariables`.
+ Here are the synthesis rules for the two `Differentiable` associated types: `TangentVector` and `AllDifferentiableVariables`.

Let "differentiation properties" refer to all stored properties of the conforming type that are not marked with `@noDerivative`. These stored properties are guaranteed by the synthesis condition to all conform to `Differentiable`.

The synthesis rules are:
* Set associated types to `Self`, if possible.
- * If the conforming type conforms to `AdditiveArithmetic`, and no `@noDerivative` stored properties exist, and all stored properties satisfy `Self == Self.TangentVector == Self.CotangentVector == Self.AllDifferentiableVariables`, then all associated types can be set to typealiases of `Self`.
- * Synthesize a single `AllDifferentiableVariables` member struct. Set `TangentVector` and `CotangentVector` to `AllDifferentiableVariables` if possible; otherwise synthesize more member structs.
+ * If the conforming type conforms to `AdditiveArithmetic`, and no `@noDerivative` stored properties exist, and all stored properties satisfy `Self == Self.TangentVector == Self.AllDifferentiableVariables`, then all associated types can be set to typealiases of `Self`.
+ * Synthesize a single `AllDifferentiableVariables` member struct. Set `TangentVector` to `AllDifferentiableVariables` if possible; otherwise synthesize more member structs.
* Regarding member struct synthesis: for each "differentiation property" in the conforming type, a corresponding stored property is synthesized in the member structs, with type equal to the property’s associated type.
- * `TangentVector` and `CotangentVector` can be set to `AllDifferentiableVariables` if all differentiation properties conform to `AdditiveArithmetic` and satisfy `Self.TangentVector == Self.CotangentVector == Self.AllDifferentiableVariables`. This is useful because it prevents redundant struct synthesis. Also, this enables [key-path-based parameter optimization][parameter-optimization] because parameters and gradients have the same type.
+ * `TangentVector` can be set to `AllDifferentiableVariables` if all differentiation properties conform to `AdditiveArithmetic` and satisfy `Self.TangentVector == Self.AllDifferentiableVariables`. This is useful because it prevents redundant struct synthesis. Also, this enables [key-path-based parameter optimization][parameter-optimization] because parameters and gradients have the same type.

A memberwise initializer is synthesized for the conforming type itself, in addition to all associated structs. This is important for differentiating struct properties accesses and synthesizing manifold operation requirements.
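
The member-struct rule described in this hunk can be sketched by hand in plain Swift. This is only an illustration under stated assumptions: `DenseSketch`, `useBias`, and the property names are hypothetical, and the code deliberately avoids the `Differentiable` protocol so that it compiles on a stock Swift toolchain. With derived conformances, the compiler would be expected to generate a struct of roughly this shape.

```swift
// Hand-written illustration only. `DenseSketch` and its members are hypothetical
// names; nothing here is synthesized by the compiler.
struct DenseSketch {
    var weight: Double
    var bias: Double
    // Stands in for a `@noDerivative` property: it gets no entry in `TangentVector`.
    var useBias: Bool

    // One stored property per differentiation property, typed as that
    // property's tangent type (`Double` is assumed to be its own tangent type).
    struct TangentVector: Equatable, AdditiveArithmetic {
        var weight: Double
        var bias: Double

        // `zero` lets derivative accumulation start from an empty contribution.
        static var zero: TangentVector { TangentVector(weight: 0, bias: 0) }

        static func + (lhs: TangentVector, rhs: TangentVector) -> TangentVector {
            TangentVector(weight: lhs.weight + rhs.weight, bias: lhs.bias + rhs.bias)
        }

        static func - (lhs: TangentVector, rhs: TangentVector) -> TangentVector {
            TangentVector(weight: lhs.weight - rhs.weight, bias: lhs.bias - rhs.bias)
        }
    }
}

// Accumulating two derivative contributions by addition, which is what the
// `AdditiveArithmetic` requirement makes possible.
let contributions = [
    DenseSketch.TangentVector(weight: 1.0, bias: 0.5),
    DenseSketch.TangentVector(weight: 2.0, bias: 0.25),
]
let total = contributions.reduce(DenseSketch.TangentVector.zero) { $0 + $1 }
```

The excluded `useBias` flag never appears in `TangentVector`, so summing contributions touches only the differentiable properties.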

@@ -229,21 +211,13 @@ Manifold operations are synthesized to forward the same operation defined on dif

```swift
// Let `Foo` be the name of the type conforming to `Differentiable`.
- func tangentVector(from cotangent: CotangentVector) -> TangentVector {
-   return TangentVector(x: x.tangentVector(from: cotangent.x), ...)
- }
func moved(along tangent: TangentVector) -> Foo {
-   return Foo(x: x.moved(along: tangent.x), ...)
+   Foo(x: x.moved(along: tangent.x), ...)
}

- // Potential shortcuts for synthesis:
- // When `TangentVector == CotangentVector`:
- func tangentVector(from cotangent: CotangentVector) -> TangentVector {
-   return cotangent
- }
- // When `Foo == TangentVector`:
+ // Potential shortcut for synthesis, when `Foo == TangentVector`:
func moved(along tangent: TangentVector) -> Foo {
-   return tangent
+   self + tangent
}
```
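
To make the forwarding pattern in this hunk concrete, here is a minimal hand-written sketch in plain Swift. It assumes a reduced stand-in protocol, `MovableSketch`, that keeps only `TangentVector` and `moved(along:)`; both that protocol and `PairSketch` are hypothetical names, not part of the library. The `Double` conformance illustrates the shortcut case where the type is its own tangent vector.

```swift
// Reduced stand-in for the protocol discussed in this document.
// `MovableSketch` is a hypothetical name used only for this sketch.
protocol MovableSketch {
    associatedtype TangentVector: AdditiveArithmetic
    func moved(along direction: TangentVector) -> Self
}

// Shortcut case: the type is its own tangent vector, so moving is addition.
extension Double: MovableSketch {
    typealias TangentVector = Double
    func moved(along direction: Double) -> Double { self + direction }
}

// Memberwise case: each property is moved along the matching tangent component.
struct PairSketch<T: MovableSketch, U: MovableSketch>: MovableSketch {
    var x: T
    var y: U

    struct TangentVector: Equatable, AdditiveArithmetic {
        var x: T.TangentVector
        var y: U.TangentVector

        static var zero: TangentVector { TangentVector(x: .zero, y: .zero) }

        static func + (lhs: TangentVector, rhs: TangentVector) -> TangentVector {
            TangentVector(x: lhs.x + rhs.x, y: lhs.y + rhs.y)
        }

        static func - (lhs: TangentVector, rhs: TangentVector) -> TangentVector {
            TangentVector(x: lhs.x - rhs.x, y: lhs.y - rhs.y)
        }
    }

    // Forwards `moved(along:)` to every stored property.
    func moved(along direction: TangentVector) -> PairSketch {
        PairSketch(x: x.moved(along: direction.x), y: y.moved(along: direction.y))
    }
}

let start = PairSketch(x: 1.0, y: 2.0)
let step = PairSketch<Double, Double>.TangentVector(x: 0.5, y: -1.0)
let end = start.moved(along: step)
```

Because `PairSketch` is generic over any `MovableSketch`, nesting one pair inside another forwards the operation recursively without extra code.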

@@ -266,11 +240,6 @@ struct GenericWrapper<T: Differentiable, U: Differentiable>: Differentiable {
// var y: U.TangentVector
// ...
// }
- // struct CotangentVector: Differentiable, AdditiveArithmetic {
- //   var x: T.CotangentVector
- //   var y: U.CotangentVector
- //   ...
- // }
// struct AllDifferentiableVariables: Differentiable {
// var x: T.AllDifferentiableVariables
// var y: U.AllDifferentiableVariables
@@ -280,10 +249,6 @@ struct GenericWrapper<T: Differentiable, U: Differentiable>: Differentiable {
// get { return AllDifferentiableVariables(x: x, y: y) }
// set { x = newValue.x; y = newValue.y }
// }
- // func tangentVector(from cotangent: CotangentVector) -> TangentVector {
- //   return TangentVector(x: x.tangentVector(from: cotangent.x),
- //                        y: y.tangentVector(from: cotangent.y))
- // }
// func moved(along tangent: TangentVector) -> Foo {
// return GenericWrapper(x: x.moved(along: tangent.x)
// y: y.moved(along: tangent.y))