Skip to content

Commit

Permalink
Ton of doctests added (#1194)
Browse files Browse the repository at this point in the history
* Ton of doctests added to index.md

* Ton of doctests added to index.md

* Ton of doctests added to index.md

* Ton of doctests added to index.md

* Ton of doctests added to index.md

* Ton of doctests added to index.md

* outdated example fixed

* outdated example fixed

* outdated example fixed

* outdated example fixed

* doctests added to adjoints.md

* doctests added to adjoints.md

* doctests added to adjoints.md

* doctests added to adjoints.md

* outdated example updated

* More doctests added

* More doctests added

* doctest added and checked properly
  • Loading branch information
arcAman07 authored Apr 3, 2022
1 parent a133200 commit 3928ab9
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 28 deletions.
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using Documenter, Zygote

makedocs(
sitename="Zygote",
doctest = true,
pages = [
"Home" => "index.md",
"Custom Adjoints" => "adjoints.md",
Expand Down
28 changes: 15 additions & 13 deletions docs/src/adjoints.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ The `@adjoint` macro is an important part of Zygote's interface; customising you

`gradient` is really just syntactic sugar around the more fundamental function `pullback`.

```julia
```jldoctest adjoints
julia> using Zygote
julia> y, back = Zygote.pullback(sin, 0.5);
julia> y
Expand Down Expand Up @@ -55,7 +57,7 @@ julia> cos(0.5)

More generally

```julia
```jldoctest adjoints
julia> function mygradient(f, x...)
_, back = Zygote.pullback(f, x...)
back(1)
Expand All @@ -76,15 +78,15 @@ Zygote has many adjoints for non-mathematical operations such as for indexing an

We can extend Zygote to a new function with the `@adjoint` function.

```julia
julia> mul(a, b) = a*b
```jldoctest adjoints
julia> mul(a, b) = a*b;
julia> using Zygote: @adjoint
julia> @adjoint mul(a, b) = mul(a, b), c̄ -> (c̄*b, c̄*a)
julia> gradient(mul, 2, 3)
(3, 2)
(3.0, 2.0)
```

It might look strange that we write `mul(a, b)` twice here. In this case we want to call the normal `mul` function for the pullback pass, but you may also want to modify the pullback pass (for example, to capture intermediate results in the pullback).
Expand Down Expand Up @@ -152,7 +154,7 @@ We usually use custom adjoints to add gradients that Zygote can't derive itself

### Gradient Hooks

```julia
```jldoctest adjoints
julia> hook(f, x) = x
hook (generic function with 1 method)
Expand All @@ -161,17 +163,17 @@ julia> @adjoint hook(f, x) = x, x̄ -> (nothing, f(x̄))

`hook` doesn't seem that interesting, as it doesn't do anything. But the fun part is in the adjoint; it's allowing us to apply a function `f` to the gradient of `x`.

```julia
```jldoctest adjoints
julia> gradient((a, b) -> hook(-, a)*b, 2, 3)
(-3, 2)
(-3.0, 2.0)
```

We could use this for debugging or modifying gradients (e.g. gradient clipping).

```julia
```jldoctest adjoints
julia> gradient((a, b) -> hook(ā -> @show(ā), a)*b, 2, 3)
ā = 3
(3, 2)
ā = 3.0
(3.0, 2.0)
```

Zygote provides both `hook` and `@showgrad` so you don't have to write these yourself.
Expand All @@ -180,7 +182,7 @@ Zygote provides both `hook` and `@showgrad` so you don't have to write these you

A more advanced example is checkpointing, in which we save memory by re-computing the pullback pass of a function during the backwards pass. To wit:

```julia
```jldoctest adjoints
julia> checkpoint(f, x) = f(x)
checkpoint (generic function with 1 method)
Expand All @@ -192,7 +194,7 @@ julia> gradient(x -> checkpoint(sin, x), 1)

If a function has side effects we'll see that the pullback pass happens twice, as expected.

```julia
```jldoctest adjoints
julia> foo(x) = (println(x); sin(x))
foo (generic function with 1 method)
Expand Down
12 changes: 7 additions & 5 deletions docs/src/complex.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,32 @@ Complex numbers add some difficulty to the idea of a "gradient". To talk about `

If `f` returns a real number, things are fairly straightforward. For ``c = x + yi`` and ``z = f(c)``, we can define the adjoint ``\bar c = \frac{\partial z}{\partial x} + \frac{\partial z}{\partial y}i = \bar x + \bar y i`` (note that ``\bar c`` means gradient, and ``c'`` means conjugate). It's exactly as if the complex number were just a pair of reals `(re, im)`. This works out of the box.

```julia
```jldoctest complex
julia> using Zygote
julia> gradient(c -> abs2(c), 1+2im)
(2 + 4im,)
(2.0 + 4.0im,)
```

However, while this is a very pragmatic definition that works great for gradient descent, it's not quite aligned with the mathematical notion of the derivative: i.e. ``f(c + \epsilon) \approx f(c) + \bar c \epsilon``. In general, such a ``\bar c`` is not possible for complex numbers except when `f` is *holomorphic* (or *analytic*). Roughly speaking this means that the function is defined over `c` as if it were a normal real number, without exploiting its complex structure – it can't use `real`, `imag`, `conj`, or anything that depends on these like `abs2` (`abs2(x) = x*x'`). (This constraint also means there's no overlap with the Real case above; holomorphic functions always return complex numbers for complex input.) But most "normal" numerical functions – `exp`, `log`, anything that can be represented by a Taylor series – are fine.

Fortunately it's also possible to get these derivatives; they are the conjugate of the gradients for the real part.

```julia
```jldoctest complex
julia> gradient(x -> real(log(x)), 1+2im)[1] |> conj
0.2 - 0.4im
```

We can check that this function is holomorphic – and thus that the gradient we got out is sensible – by checking the Cauchy-Riemann equations. In other words this should give the same answer:

```julia
```jldoctest complex
julia> -im*gradient(x -> imag(log(x)), 1+2im)[1] |> conj
0.2 - 0.4im
```

Notice that this fails in a non-holomorphic case, `f(x) = log(x')`:

```julia
```jldoctest complex
julia> gradient(x -> real(log(x')), 1+2im)[1] |> conj
0.2 - 0.4im
Expand Down
20 changes: 10 additions & 10 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@ Zygote is easy to understand since, at its core, it has a one-function API (`pul

`gradient` calculates derivatives. For example, the derivative of ``3x^2 + 2x + 1`` is ``6x + 2``, so when `x = 5`, `dx = 32`.

```julia
```jldoctest index
julia> using Zygote
julia> gradient(x -> 3x^2 + 2x + 1, 5)
(32,)
(32.0,)
```

`gradient` returns a tuple, with a gradient for each argument to the function.

```julia
```jldoctest index
julia> gradient((a, b) -> a*b, 2, 3)
(3, 2)
(3.0, 2.0)
```

This will work equally well if the arguments are arrays, structs, or any other Julia type, but the function should return a scalar (like a loss or objective ``l``, if you're doing optimisation / ML).
Expand All @@ -48,7 +48,7 @@ julia> gradient(x -> 3x^2 + 2x + 1, 1//4)

Control flow is fully supported, including recursion.

```julia
```jldoctest index
julia> function pow(x, n)
r = 1
for i = 1:n
Expand All @@ -59,26 +59,26 @@ julia> function pow(x, n)
pow (generic function with 1 method)
julia> gradient(x -> pow(x, 3), 5)
(75,)
(75.0,)
julia> pow2(x, n) = n <= 0 ? 1 : x*pow2(x, n-1)
pow2 (generic function with 1 method)
julia> gradient(x -> pow2(x, 3), 5)
(75,)
(75.0,)
```

Data structures are also supported, including mutable ones like dictionaries. Arrays are currently immutable, though [this may change](https://github.com/FluxML/Zygote.jl/pull/75) in future.

```julia
```jldoctest index
julia> d = Dict()
Dict{Any,Any} with 0 entries
Dict{Any, Any}()
julia> gradient(5) do x
d[:x] = x
d[:x] * d[:x]
end
(10,)
(10.0,)
julia> d[:x]
5
Expand Down

0 comments on commit 3928ab9

Please sign in to comment.