also known as reverse-mode automatic differentiation.
Given a model, some data, and a loss function, this answers the question
"what direction, in the space of the model's parameters, reduces the loss fastest?"

### `gradient(f, x)` interface

Julia's ecosystem has many versions of `gradient(f, x)`, which evaluates `y = f(x)` and then returns `∂y_∂x`. The details of how they do this vary, but the interface is similar. An incomplete list is (alphabetically):

```julia
julia> Zygote.withgradient(x -> sum(sqrt, x), [1 4 16.])
(val = 7.0, grad = ([0.5 0.25 0.125],))
```
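
For comparison, here is the same calculation done with the plain `gradient` of two of these packages -- a small sketch, not the full list. Zygote returns a tuple with one entry per argument, while ForwardDiff returns an array of the same shape as `x`:

```julia
julia> using Zygote, ForwardDiff

julia> Zygote.gradient(x -> sum(sqrt, x), [1 4 16.])   # reverse mode, one entry per argument
([0.5 0.25 0.125],)

julia> ForwardDiff.gradient(x -> sum(sqrt, x), [1 4 16.])  # forward mode, result shaped like x
1×3 Matrix{Float64}:
 0.5  0.25  0.125
```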

These all show the same `∂y_∂x` with respect to the array `x`. Sometimes, the result is within a tuple or a NamedTuple, containing `y` as well as the gradient.

Note that in all cases, only code executed within the call to `gradient` is differentiated. Calculating the objective function before calling `gradient` will not work, as all information about the steps from `x` to `y` has been lost. For example:

```julia
julia> x = [1 4 16.];  # the example array from above

julia> y = sum(sqrt, x)  # calculate the forward pass alone
7.0

julia> y isa Float64 # has forgotten about sqrt and sum
true

julia> Zygote.gradient(x -> y, x) # this cannot work, and gives zero
(nothing,)
```
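
To get the gradient, the whole computation needs to happen inside the function passed to `gradient`, for instance:

```julia
julia> Zygote.gradient(x -> sum(sqrt, x), x)  # forward pass and gradient together
([0.5 0.25 0.125],)
```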

### `gradient(f, model)` for Flux models

However, the parameters of a Flux model are encapsulated inside the various layers. The model is a set of nested structures, and the gradients `∂loss_∂model` which Flux uses are similarly nested objects.
For example, let's set up a simple model & loss:

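One minimal setup, consistent with the numbers in the examples below -- the exact constructor and loss here are assumptions (an `Embedding` layer wrapped in a `Chain`, and a sum-of-squares loss), but the names `model`, `loss` and `grads_z` are the ones used in what follows, and `grad_e` denotes the corresponding gradient computed analogously with a different AD package:

```julia
julia> using Flux, Zygote

julia> model = Chain(Embedding(reshape(1:6, 2, 3) .+ 0.0), softmax);  # first layer holds a 2×3 weight matrix

julia> model.layers[1].weight
2×3 Matrix{Float64}:
 1.0  3.0  5.0
 2.0  4.0  6.0

julia> loss(m) = sum(abs2, m(1));  # apply the model to index 1, then sum the squares

julia> loss(model)
0.6067761335170363

julia> grads_z = Zygote.gradient(loss, model);  # a 1-tuple, one entry for the one argument `model`
```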
While the type returned for `∂loss_∂model` varies between packages, these gradients all have the same nested structure, matching that of the model. This is all that Flux needs.

```julia
julia> grads_z[1].layers[1].weight  # the gradient for the first layer's weight matrix
2×3 Matrix{Float64}:
-0.181715 0.0 0.0
0.181715 0.0 0.0

julia> grad_e.layers[1].weight  # the same entry in a gradient from a different package
2×3 Matrix{Float64}:
-0.181715 0.0 0.0
0.181715 0.0 0.0
```

Here's Flux updating the model using each gradient:
<!--- perhaps we should trim this?? --->

```julia
julia> opt = Flux.setup(Descent(1/3), model)  # optimiser state, mirroring the model's nested structure
(layers = ((weight = Leaf(Descent(0.333333), nothing),), ()),)

julia> model_z = deepcopy(model);

julia> Flux.update!(opt, model_z, grads_z[1]);

julia> model_z.layers[1].weight # updated weight matrix
2×3 Matrix{Float64}:
1.06057 3.0 5.0
1.93943 4.0 6.0

julia> model_e = deepcopy(model);

julia> Flux.update!(opt, model_e, grad_e)[2][1].weight # same update
2×3 Matrix{Float64}:
1.06057 3.0 5.0
1.93943 4.0 6.0
```

In this case they are all identical, but there are some caveats, explored below.


As an aside, Tapir seems not to work just yet:
```julia
julia> Tapir_grad(f, xs...) = Tapir.value_and_pullback!!(Tapir.build_rrule(f, xs...), 1.0, f, xs...);

julia> _, grad_p = Tapir_grad(loss, model)
(0.6067761335170363, (NoTangent(), Tangent{@NamedTuple{layers::Tuple{Tangent{@NamedTuple{weight::Matrix{Float64}}}, NoTangent}}}((layers = (Tangent{@NamedTuple{weight::Matrix{Float64}}}((weight = [0.0 0.0 0.0; 0.0 0.0 0.0],)), NoTangent()),))))

julia> grad_p.fields.layers[1].fields.weight
2×3 Matrix{Float64}:
0.0 0.0 0.0
0.0 0.0 0.0
```

<!--- I made an issue... perhaps fixed now?? --->

<hr/>

## Automatic Differentiation Packages

Both Zygote and Tracker were written for Flux. At present, Flux loads Zygote, exports `Zygote.gradient`, and calls it within `Flux.train!`, but apart from that there is very little coupling between Flux and the automatic differentiation package.
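
In practice this means the exported `gradient` is ready to use after `using Flux`, without loading Zygote yourself. A quick check, assuming the current re-export (the exact set of exported names is Flux's choice and may change):

```julia
julia> using Flux  # brings Zygote's `gradient` into scope

julia> gradient(x -> sum(abs2, x), [1.0, 2.0])
([2.0, 4.0],)
```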

### [Zygote.jl](https://github.com/FluxML/Zygote.jl)

Source-to-source, within Julia.
* Returns nested NamedTuples and Tuples, and uses `nothing` to mean zero.
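
A small illustration of that convention, using an argument which does not affect the result:

```julia
julia> Zygote.gradient((x, y) -> sum(x .^ 2), [1.0, 2.0], [3.0, 4.0])  # y is unused, so its gradient is `nothing`
([2.0, 4.0], nothing)
```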


!!! compat "Deprecated: Zygote's implicit mode"
    Flux's default used to work like this, instead of using deeply nested trees for gradients as above:
    ```julia
    julia> ps = Flux.params(model)  # dictionary-like object, with global `objectid` refs
    Params([Float32[1.0 3.0 5.0; 2.0 4.0 6.0]])

    julia> val, grad = Zygote.withgradient(() -> loss(model), ps)
    (val = 0.6067761f0, grad = Grads(...))

    julia> grad[model.layers[1].weight]  # another dictionary, indexed by parameter arrays
    2×3 Matrix{Float32}:
     0.0  0.0  -0.181715
     0.0  0.0   0.181715
    ```
    The code inside Zygote is much the same -- do not expect large changes in speed, nor any changes in what works and what does not.

### [Tracker.jl](https://github.com/FluxML/Tracker.jl)
