Skip to content

Commit

Permalink
docs: align a few misaligned code blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Jan 9, 2024
1 parent 83366a5 commit 388315c
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions docs/src/rule_author/example.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,36 +58,36 @@ Read more about writing rules for constructors and callable objects [here](@ref
The `rrule` returns the primal result `y`, and the pullback function.
It is a _very_ good idea to name your pullback function, so that they are helpful when appearing in the stacktrace.
```julia
y = foo_mul(foo, b)
y = foo_mul(foo, b)
```
Computes the primal result.
It is possible to change the primal computation so that work can be shared between the primal and the pullback.
See e.g. [the rule for `sort`](https://github.com/JuliaDiff/ChainRules.jl/blob/a75193768775975fac5578c89d1e5f50d7f358c2/src/rulesets/Base/sort.jl#L19-L35), where the sorting is done only once.
```julia
function foo_mul_pullback(ȳ)
...
return f̄, f̄oo, b̄
end
function foo_mul_pullback(ȳ)
...
return f̄, f̄oo, b̄
end
```
The pullback function takes in the tangent of the primal output (``) and returns the tangents of the primal inputs.
Note that it returns a tangent for the primal function in addition to the tangents of primal arguments.

Finally, computing the tangents of primal inputs:
```julia
= NoTangent()
= NoTangent()
```
The function `foo_mul` has no fields (i.e. it is not a closure) and can not be perturbed.
Therefore its tangent (``) is a `NoTangent`.
```julia
f̄oo = Tangent{Foo}(; A=* b', c=ZeroTangent())
f̄oo = Tangent{Foo}(; A=* b', c=ZeroTangent())
```
The struct `foo::Foo` gets a `Tangent{Foo}` structural tangent, which stores the tangents of fields of `foo`.

The tangent of the field `A` is `ȳ * b'`,

The tangent of the field `c` is `ZeroTangent()`, because `c` can be perturbed but has no effect on the primal output.
```julia
= @thunk(foo.A' * ȳ)
= @thunk(foo.A' * ȳ)
```
The tangent of `b` is `foo.A' * ȳ`, but we have wrapped it into a `Thunk`, a tangent type that represents delayed computation.
The idea is that in case the tangent is not used anywhere, the computation never happens.
Expand Down

0 comments on commit 388315c

Please sign in to comment.