
Make all pullbacks only take a single input #163


Merged: 5 commits, May 19, 2020
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.7.5"
version = "0.8.0"

[deps]
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
22 changes: 12 additions & 10 deletions docs/src/index.md
@@ -126,26 +126,26 @@ This document will explain this point of view in some detail.
##### Some terminology/conventions.

Let ``p`` be an element of type ``M``, which is defined by some assignment of numbers ``x_1,...,x_m``,
say ``(x_1,...,x_m) = (a_1,...,a_m)``.

A _function_ ``f:M \to K`` on ``M`` is (for simplicity) a polynomial in ``K[x_1, ..., x_m]``.

The tangent space ``T_pM`` of ``M`` at point ``p`` is the ``K``-vector space spanned by derivations ``d/dx``.
Tangent vectors act linearly on the space of functions, in the usual way. Our starting point is
that we know how to write down ``d/dx(f) = df/dx``.
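
For example (a concrete illustration added here, not from the original text): with ``m = 2`` and the polynomial ``f = x_1^2 x_2``, the basis derivations of ``T_pM`` act by

```math
\frac{d}{dx_1}(f) = 2 x_1 x_2, \qquad \frac{d}{dx_2}(f) = x_1^2,
```

evaluated at the numbers ``(a_1, a_2)`` that define ``p``.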

The collection of tangent spaces ``{T_pM}`` for ``p\in M`` is called the _tangent bundle_ of ``M``.

Let ``df`` denote the first order information of ``f`` at each point. This is called the differential of ``f``.
If the derivatives of ``f`` and ``g`` agree at ``p``, we say that ``df`` and ``dg`` represent the same cotangent at ``p``.
The covectors ``dx_1, ..., dx_m`` form the basis of the cotangent space ``T^*_pM`` at ``p``. Notice that this vector space is
dual to ``T_pM``.

The collection of cotangent spaces ``{T^*_pM}`` for ``p\in M`` is called the _cotangent bundle_ of ``M``.

##### Push-forwards and pullbacks

Let ``N`` be another type, defined by numbers ``y_1,...,y_n``, and let ``g:M \to N`` be a _map_, that is,
an ``n``-dimensional vector ``(g_1, ..., g_n)`` of functions on ``M``.

We define the _push-forward_ ``g_*:TM \to TN`` between tangent bundles by ``g_*(X)(h) = X(h\circ g)`` for any tangent vector ``X`` and function ``h`` on ``N``.
@@ -154,7 +154,7 @@ We have ``g_*(d/dx_i)(y_j) = dg_j/dx_i``, so the push-forward corresponds to the
Similarly, the pullback of the differential ``df`` is defined by
``g^*(df) = d(f\circ g)``. So for a coordinate differential ``dy_j``, we have
``g^*(dy_j) = d(g_j)``. Notice that this is a covector, and we could have defined the pullback by its action on vectors by
``g^*(dh)(X) = g_*(X)(h) = X(h\circ g)`` for any function ``h`` on ``N`` and ``X\in TM``. In particular,
``g^*(dy_j)(d/dx_i) = d(g_j)/dx_i``. If you work out the action in a basis of the cotangent space, you see that it acts
by the adjoint of the Jacobian.
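
As a quick illustration (an example added here; the map below is made up for concreteness): take ``m = 2``, ``n = 1``, and ``g(x_1, x_2) = x_1 x_2``. Then

```math
g_*(d/dx_1) = x_2 \, d/dy_1, \qquad g_*(d/dx_2) = x_1 \, d/dy_1, \qquad
g^*(dy_1) = x_2 \, dx_1 + x_1 \, dx_2,
```

so the push-forward acts on tangent vectors by the (here ``1 \times 2``) Jacobian, and the pullback acts on covectors by its adjoint.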

@@ -170,13 +170,13 @@ But pulling back gradients still should not be a thing.

If the goal is to evaluate the gradient of a function ``f = h\circ g:M \to N \to K``, where ``g`` is a map and ``h`` is a function,
we have two obvious options:
First, we may push-forward a basis of ``T_pM`` to ``TK``, which we identify with ``K`` itself.
This results in ``m`` scalars, representing components of the gradient.
Step-by-step in coordinates:
1. Compute the push-forward of the basis of ``T_pM``, i.e. just the columns of the Jacobian ``dg_i/dx_j``.
2. Compute the push-forward of the function ``h`` (consider it as a map; ``K`` is also a manifold!) to get ``h_*(g_*T_pM) = \sum_i dh/dy_i (dg_i/dx_j)``

Second, we pull back the differential ``dh``:
1. compute ``dh = dh/dy_1,...,dh/dy_n`` in coordinates.
2. pull back by (in coordinates) multiplying with the adjoint of the Jacobian, resulting in ``g^*(dh) = \sum_i (dg_i/dx_j)(dh/dy_i)``.
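
The two routes can be checked numerically. Below is a minimal sketch (added for illustration; `g`, `h`, and the hand-written Jacobian are made-up examples, and plain matrices stand in for the push-forward and pullback):

```julia
# Illustrative only: a map g from M = ℝ² to N = ℝ², and a function h from N to K = ℝ.
g(x) = [x[1] * x[2], sin(x[1])]
h(y) = y[1] + y[2]^2

x = [1.0, 2.0]
y = g(x)

# Jacobian of g at x and differential of h at y, written out by hand for this example.
Jg = [x[2]      x[1];
      cos(x[1]) 0.0]          # entries dg_i/dx_j
dh = [1.0, 2.0 * y[2]]        # entries dh/dy_i

# Option 1: push forward each basis vector of T_pM (a column of Jg), then apply dh.
grad_forward = [sum(dh .* Jg[:, j]) for j in 1:2]

# Option 2: pull back dh in one go by multiplying with the adjoint of Jg.
grad_reverse = Jg' * dh

# Both are the gradient of h ∘ g at x.
@assert grad_forward ≈ grad_reverse
```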

@@ -263,12 +263,14 @@ Similarly every `pullback` returns an extra `∂self`, which for things without
- **Pullback**
- returned by `rrule`
- takes output space wobbles, gives input space wiggles
- 1 argument per original function return
- Argument structure matches structure of primal function output
- If primal function returns a tuple, then pullback takes in a tuple of differentials.
- 1 return per original function argument + 1 for the function itself

- **Pushforward:**
- part of `frule`
- takes input space wiggles, gives output space wobbles
- Argument structure matches primal function argument structure, but passed as a tuple at start of `frule`
- 1 argument per original function argument + 1 for the function itself
- 1 return per original function return
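
A hand-written sketch of these conventions for a two-output function (illustrative only; `twoout` and its rules below are made-up examples, not part of ChainRulesCore):

```julia
using ChainRulesCore

twoout(x) = (sin(x), cos(x))   # primal returns a tuple of two outputs

function ChainRulesCore.rrule(::typeof(twoout), x)
    s, c = sin(x), cos(x)
    # Pullback: one tuple argument matching the structure of the primal output,
    # one return per primal argument plus one for the function itself.
    twoout_pullback((Δs, Δc)) = (NO_FIELDS, c * Δs - s * Δc)
    return (s, c), twoout_pullback
end

function ChainRulesCore.frule((Δself, Δx), ::typeof(twoout), x)
    s, c = sin(x), cos(x)
    # Pushforward: one return per primal return, computed from the input tangent Δx.
    return (s, c), (c * Δx, -s * Δx)
end
```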

6 changes: 4 additions & 2 deletions src/rule_definition_tools.jl
@@ -25,7 +25,7 @@ methods for `frule` and `rrule`:
function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...)
Ω = f(x₁, x₂, ...)
\$(statement₁, statement₂, ...)
return Ω, (ΔΩ₁, ΔΩ₂, ...) -> (
return Ω, ((ΔΩ₁, ΔΩ₂, ...)) -> (
NO_FIELDS,
∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...),
∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...),
@@ -185,8 +185,10 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
propagation_expr(Δs, ∂s)
end

# Multi-output functions have pullbacks with a tuple input that will be destructured
pullback_input = n_outputs == 1 ? first(Δs) : Expr(:tuple, Δs...)
pullback = quote
function $(propagator_name(f, :pullback))($(Δs...))
function $(propagator_name(f, :pullback))($pullback_input)
return (NO_FIELDS, $(pullback_returns...))
end
end
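
For intuition, here is roughly what that generated pullback looks like for the two-output `simo` example from the tests below (a hand-expanded sketch, not the literal macro output):

```julia
# Hand-expanded sketch of the pullback produced by @scalar_rule(simo(x), 1, 2):
# a single argument that is a tuple with one differential per primal output.
simo_pullback((ΔΩ₁, ΔΩ₂)) = (NO_FIELDS, 1 * ΔΩ₁ + 2 * ΔΩ₂)
```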
14 changes: 14 additions & 0 deletions test/rules.jl
@@ -123,3 +123,17 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
@inferred frule((Zero(), sx, sy), very_nice, 1, 2)
end
end


simo(x) = (x, 2x)
@scalar_rule(simo(x), 1, 2)

@testset "@scalar_rule with multiple outputs" begin
y, simo_pb = rrule(simo, π)

@test simo_pb((10, 20)) == (NO_FIELDS, 50)

y, ẏ = frule((NO_FIELDS, 50), simo, π)
@test y == (π, 2π)
@test ẏ == (50, 100)
end
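
In this test the primal `simo` has two outputs, so the pullback receives the tuple `(10, 20)` and combines it with the declared partials `1` and `2`: `1 * 10 + 2 * 20 = 50`, returned after `NO_FIELDS` for the function itself. The `frule` call goes the other way: the input tangent `50` is pushed forward through the same partials to give the output tangents `(50, 100)`.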