Merge branch 'master' into mz/deprecate_nograd
mzgubic authored Jul 7, 2022
2 parents 4bb6b4d + 6336b60 commit af434d6
Showing 5 changed files with 181 additions and 79 deletions.
3 changes: 2 additions & 1 deletion docs/make.jl
@@ -8,9 +8,10 @@ using Documenter, Zygote

makedocs(
sitename="Zygote",
-doctest = true,
+doctest = false,
pages = [
"Home" => "index.md",
+"Limitations" => "limitations.md",
"Custom Adjoints" => "adjoints.md",
"Utilities" => "utils.md",
"Complex Differentiation" => "complex.md",
150 changes: 150 additions & 0 deletions docs/src/limitations.md
@@ -0,0 +1,150 @@
# Limitations

Zygote aims to support differentiating any code you might write in Julia, but it still has a few limitations. Notably, you might encounter errors when trying to differentiate:
- array mutation
- `try`/`catch` statements
- "foreign call" expressions

In this section, we will introduce examples where each of these errors occurs as well as possible work-arounds.

## Array mutation

Array mutation is by far the most commonly encountered Zygote limitation.

Automatic differentiation (AD) systems like Zygote are built on basic principles of calculus where we encounter _pure_ functions. This means that the function, ``y = f(x)``, does not modify ``x`` and only produces the output ``y`` based on ``x``. If we have a chain of functions, such as ``y = h(g(f(x)))``, we can apply the chain rule to differentiate it. AD systems are built to programmatically apply the chain rule to a series of function calls. Unfortunately, typical programs do not behave this way. We might allocate some memory, `x`, then call a function `y = f!(x)` that modifies `x` to produce the output `y`. This mutating behavior is a _side-effect_ of `f!`. Side-effects are difficult for AD systems to handle, because they must track changes to mutated variables and store older versions of the variable. For these reasons, Zygote does not handle array mutation for now.
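
As a minimal sketch of the pure-function case, the snippet below differentiates a chain of three pure functions and compares the result with the chain rule worked out by hand. The function bodies and the input vector are illustrative assumptions, not part of the original text.

```julia
using Zygote

# Three pure functions: each returns a new value and leaves its input untouched
f(x) = 2 .* x
g(y) = y .^ 2
h(z) = sum(z)

x = [1.0, 2.0, 3.0]
grad, = gradient(x -> h(g(f(x))), x)

# Chain rule by hand: d/dx sum((2x).^2) = 8x
grad ≈ 8 .* x   # true
```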

Let's explore this with a more concrete example. Here we define a simple mutating function, `f!`, which modifies the elements of its input argument, `x`, in place.
```julia
function f!(x)
    x .= 2 .* x

    return x
end
```
Let's see what happens when we differentiate `f!`:
```julia
julia> gradient(rand(3)) do x
           sum(f!(x))
       end
ERROR: Mutating arrays is not supported -- called copyto!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in-place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
https://fluxml.ai/Zygote.jl/dev/limitations.html#Array-mutation

Stacktrace:
...
```
We got an error message and a long stacktrace. The error informs us that our code performs array mutation by calling `copyto!` (we might not have directly called this function, but it is being invoked somewhere in the call stack). We see that our code includes `x .= ...` which is given as an example of array mutation. Other examples of mutating operations include:
- setting values (`x .= ...`)
- appending/popping values (`push!(x, v)` / `pop!(x)`)
- calling mutating functions (`mul!(C, A, B)`)
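
Each of the mutating operations listed above usually has an allocating counterpart that Zygote can differentiate. The sketch below shows one possible substitution for each; the helper names `scaled`, `appended`, and `matprod` are hypothetical and exist only for this illustration.

```julia
using Zygote

# Hypothetical non-mutating counterparts of the operations listed above
scaled(x) = 2 .* x             # instead of x .= 2 .* x
appended(x, v) = vcat(x, v)    # instead of push!(x, v)
matprod(A, B) = A * B          # instead of mul!(C, A, B)

x = [1.0, 2.0, 3.0]
g_scaled, = gradient(x -> sum(scaled(x)), x)
g_appended, = gradient(x -> sum(appended(x, 4.0)), x)

A = [1.0 2.0; 3.0 4.0]
g_matprod, = gradient(B -> sum(matprod(A, B)), [1.0 0.0; 0.0 1.0])
```

These variants allocate a new array on every call, so they trade some performance for differentiability.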

!!! warning

    Non-mutating functions may also use mutation under the hood. This can be done for performance reasons or code re-use.

```julia
function g!(x, y)
    x .= 2 .* y

    return x
end
g(y) = g!(similar(y), y)
```
Here `g` is a "non-mutating function," and it indeed does not mutate `y`, its only argument. But it still allocates a new array and calls `g!` on this array which will result in a mutating operation. You may encounter such functions when working with another package.

Specifically for array mutation, we can use [`Zygote.Buffer`](@ref) to re-write our function. For example, let's fix the function `g` above so that it calls `g!` on a `Buffer` instead of on a freshly allocated array.
```julia
function g!(x, y)
    x .= 2 .* y

    return x
end

function g(y)
    x = Zygote.Buffer(y) # Buffer supports syntax like similar
    g!(x, y)
    return copy(x) # this step makes the Buffer immutable (w/o actually copying)
end

julia> gradient(rand(3)) do y
           sum(g(y))
       end
([2.0, 2.0, 2.0],)
```
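
`Zygote.Buffer` also accepts element-wise writes, which can help when porting loop-based code. The following is a minimal sketch under that assumption; `double_all` is a hypothetical helper, not a Zygote function.

```julia
using Zygote

# Hypothetical helper: fill a Buffer element by element, then freeze it
function double_all(y)
    buf = Zygote.Buffer(y)     # same shape and element type as y
    for i in eachindex(y)
        buf[i] = 2 * y[i]      # writes to a Buffer are allowed
    end
    return copy(buf)           # returns an ordinary array
end

grad, = gradient(y -> sum(double_all(y)), [1.0, 2.0, 3.0])
```

The gradient should match the broadcast version: every element of the input contributes with a factor of 2.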

## Try-catch statements

Any expression involving a `try`/`catch` statement is not supported.
```julia
function tryme(x)
    try
        2 * x
    catch e
        throw(e)
    end
end

julia> gradient(rand(3)) do x
           sum(tryme(x))
       end
ERROR: Compiling Tuple{typeof(tryme), Vector{Float64}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/dev/limitations.html#try-catch-statements-1

Stacktrace:
...
```
Here `tryme` uses a `try`/`catch` statement and, as expected, Zygote throws an error when trying to differentiate it. `try`/`catch` expressions are used for error handling, but they are less common in Julia than in some other languages.
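
The restriction applies only to code that Zygote itself differentiates. One option, sketched here as an assumption rather than a recommendation from the Zygote authors, is to keep the `try`/`catch` around the `gradient` call instead of inside the function being differentiated. The names `risky_loss` and `gradient_or_nothing` are hypothetical.

```julia
using Zygote

# Hypothetical loss: sqrt throws a DomainError for negative inputs
risky_loss(x) = sum(sqrt.(x))

# The try/catch wraps the gradient call; it is not part of the differentiated code
function gradient_or_nothing(x)
    try
        return gradient(risky_loss, x)[1]
    catch
        return nothing   # sentinel: differentiation failed
    end
end

g_ok = gradient_or_nothing([1.0, 4.0])     # 1 ./ (2 .* sqrt.(x))
g_bad = gradient_or_nothing([-1.0, 4.0])   # the forward pass throws
```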

## Foreign call expressions

Foreign call expressions refer to expressions that call external libraries such as code written in C or Fortran. You may want to read more about these calls in the [Julia documentation](https://docs.julialang.org/en/v1/manual/calling-c-and-fortran-code/). Scientific computing libraries in Julia may call established C or Fortran libraries under the hood. Since the underlying code for a foreign call expression is not in Julia, it is not possible for Zygote to differentiate this expression.

Below, we define a function that calls a standard C function, `clock`. This function returns the processor time consumed by the program so far, measured in clock ticks, which we read here as an `Int32`.
```julia
julia> jclock(x) = ccall(:clock, Int32, ()) * x
jclock (generic function with 1 method)

julia> jclock(2)
30921278

julia> gradient(jclock, rand())
ERROR: Can't differentiate foreigncall expression
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/dev/limitations.html

Stacktrace:
...
```
`jclock` multiplies the result of our C function by its argument. When we try to differentiate with respect to this argument, we get a `foreigncall` error.

## Solutions

For all of the errors above, the suggested solutions are similar. You have the following possible work-arounds available (in order of preference):
1. avoid the error-inducing operation (e.g. do not use mutating functions)
2. define a [custom `ChainRulesCore.rrule`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html)
3. open an [issue on Zygote](https://github.com/FluxML/Zygote.jl/issues)

Avoiding the operation is simple: just don't do it! If you are using a mutating function, try to use a non-mutating variant. If you are using `try`/`catch` statements, try to use more graceful error handling such as returning `nothing` or another sentinel value. Recall that array mutation can also be avoided by using [`Zygote.Buffer`](@ref) as discussed above.
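
As a rough illustration of the sentinel-value idea, the sketch below checks the input with a branch and returns `nothing` instead of throwing and catching a `DomainError`. The helpers `safe_log` and `loss`, and the fallback value of zero, are assumptions made for this example.

```julia
using Zygote

# Hypothetical helper: return `nothing` for invalid input instead of throwing
safe_log(x) = x > 0 ? log(x) : nothing

function loss(x)
    y = safe_log(x)
    y === nothing && return zero(x)   # fallback when the input is out of domain
    return 2 * y
end

grad, = gradient(loss, 2.0)   # d/dx 2log(x) = 2/x
```

Ordinary branches like these are differentiable, so the error-handling logic no longer needs `try`/`catch`.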

Sometimes, we cannot avoid expressions that Zygote cannot differentiate, but we may be able to manually derive a gradient. In these cases, you can write [a custom `rrule`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html) using ChainRules.jl. Please refer to the linked ChainRules documentation for how to do this. _This solution is the only solution available for foreign call expressions._ Below, we provide a custom `rrule` for `jclock`.
```julia
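using Zygote, ChainRulesCore

jclock(x) = ccall(:clock, Int32, ()) * x

function ChainRulesCore.rrule(::typeof(jclock), x)
    c = ccall(:clock, Int32, ()) # the foreign call, treated as a constant
    y = c * x
    # d(c * x)/dx = c, so the pullback scales the incoming cotangent by c
    pb(ȳ) = (ChainRulesCore.NoTangent(), ȳ * c)

    return y, pb
end

julia> gradient(jclock, rand()) # returns a 1-tuple holding the clock reading; the value varies between runs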
```
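
When the foreign call produces a value that should simply be treated as a constant, a shorter route may be `ChainRulesCore.ignore_derivatives`, which is itself backed by a ChainRules rule. The sketch below assumes that treating the clock reading as a constant is acceptable; `scaled_clock` is a hypothetical name.

```julia
using Zygote, ChainRulesCore

# Hypothetical variant of jclock: the foreign call is hidden from the AD system
function scaled_clock(x)
    c = ChainRulesCore.ignore_derivatives() do
        ccall(:clock, Int32, ())   # never differentiated
    end
    return c * x
end

grad, = gradient(scaled_clock, 1.5)   # equals the clock reading used in the call
```

This only works because the derivative with respect to the clock value is not needed; a foreign call whose inputs must be differentiated still requires a hand-written `rrule`.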

Lastly, if the code causing problems can be fixed, but it is package code instead of your code, then you should open an issue. For functions built into Julia or its standard libraries, you can open an issue with Zygote.jl or ChainRules.jl. For functions in other packages, you can open an issue with the corresponding package issue tracker.
11 changes: 9 additions & 2 deletions src/compiler/reverse.jl
@@ -118,7 +118,10 @@ function instrument(ir::IR)
if isexpr(ex, :foreigncall, :isdefined)
continue
elseif isexpr(ex, :enter, :leave)
-error("try/catch is not supported.")
+error("""try/catch is not supported.
+Refer to the Zygote documentation for fixes.
+https://fluxml.ai/Zygote.jl/dev/limitations.html#Try-catch-statements-1
+""")
elseif isexpr(ex, :(=))
@assert ex.args[1] isa GlobalRef
pr[v] = xcall(Zygote, :global_set, QuoteNode(ex.args[1]), ex.args[2])
@@ -277,7 +280,11 @@ function adjoint(pr::Primal)
grads[ex.val] = grads[v]
elseif isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo)
elseif isexpr(ex)
-push!(rb, stmt(xcall(Base, :error, "Can't differentiate $(ex.head) expression"),
+push!(rb, stmt(xcall(Base, :error, """
+Can't differentiate $(ex.head) expression.
+You might want to check the Zygote limitations documentation.
+https://fluxml.ai/Zygote.jl/dev/limitations.html
+"""),
line = b[v].line))
else # A literal value
continue
90 changes: 14 additions & 76 deletions src/lib/array.jl
@@ -65,15 +65,26 @@ _droplike(dy::Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose}, dxv::Abstra

@adjoint getindex(::Type{T}, xs...) where {T} = T[xs...], dy -> (nothing, dy...)

+_throw_mutation_error(f, args...) = error("""
+Mutating arrays is not supported -- called $f($(join(map(typeof, args), ", ")), ...)
+This error occurs when you ask Zygote to differentiate operations that change
+the elements of arrays in place (e.g. setting values with x .= ...)
+
+Possible fixes:
+- avoid mutating operations (preferred)
+- or read the documentation and solutions for this error
+https://fluxml.ai/Zygote.jl/dev/limitations.html#Array-mutation-1
+""")
+
@adjoint! setindex!(xs::AbstractArray, x...) = setindex!(xs, x...),
-_ -> error("Mutating arrays is not supported -- called setindex!(::$(typeof(xs)), _...)")
+_ -> _throw_mutation_error(setindex!, xs)

@adjoint! copyto!(xs, args...) = copyto!(xs, args...),
-_ -> error("Mutating arrays is not supported -- called copyto!(::$(typeof(xs)), _...)")
+_ -> _throw_mutation_error(copyto!, xs)

for f in [push!, pop!, pushfirst!, popfirst!]
@eval @adjoint! $f(x::AbstractVector, ys...) = $f(x, ys...),
-_ -> error("Mutating arrays is not supported -- called $($f)(::$(typeof(x)), _...)")
+_ -> _throw_mutation_error($f, x)
end

# General
@@ -306,88 +317,15 @@ end
sum(xs, dims = dims), Δ -> (nothing,)
end


function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
y, back = pullback(cx, ((f, xs) -> prod(f.(xs))), f, xs)
y, ȳ -> (nothing, back(ȳ)...)
end

@adjoint function maximum(xs::AbstractArray; dims = :)
max, i = findmax(xs, dims = dims)
max, function (Δ)
Δ isa Real && abs(Δ) <= sqrt(eps(float(Δ))) && return nothing
Δ′ = zero(xs)
Δ′[i] = Δ
return (Δ′,)
end
end

@adjoint function minimum(xs::AbstractArray; dims = :)
min, i = findmin(xs, dims = dims)
min, function (Δ)
Δ′ = zero(xs)
Δ′[i] = Δ
return (Δ′,)
end
end

@adjoint function dropdims(xs::AbstractArray; dims)
dropdims(xs, dims = dims), Δ -> (reshape(Δ, size(xs)...),)
end

@adjoint real(x::AbstractArray) = real(x), r̄ -> (real(r̄),)
@adjoint conj(x::AbstractArray) = conj(x), r̄ -> (conj(r̄),)
@adjoint imag(x::AbstractArray) = imag(x), ī -> (complex.(0, real.(ī)),)

@adjoint function mean(xs::AbstractArray; dims = :)
return mean(xs, dims=dims), Δ -> (_backmean(xs,Δ,dims),)
end
_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs)
_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(xs,i),*,dims)

@adjoint function Statistics.var(xs::AbstractArray; corrected::Bool=true, dims=:, mean=mean(xs, dims=dims))
return Statistics.var(xs; corrected=corrected, mean=mean, dims=dims), Δ -> _backvar(xs, Δ, corrected, mean, dims)
end
_backvar(xs, Δ, corrected::Bool, mean, dims) = _backvar(xs, Δ, mapreduce(i -> size(xs,i),*,dims) - corrected, mean)
_backvar(xs, Δ, corrected::Bool, mean, ::Colon) = _backvar(xs, Δ, length(xs) - corrected, mean)
_backvar(xs, Δ, N::Int, mean) = (convert(eltype(xs), 2/N) .* Δ .* (xs .- mean),)

@adjoint function Statistics.std(xs::AbstractArray; corrected::Bool=true, dims=:, mean=mean(xs, dims=dims))
s = Statistics.std(xs; corrected=corrected, mean=mean, dims=dims)
return s, Δ -> _backvar(xs, Δ ./ (2 .* s), corrected, mean, dims)
end

@adjoint function cumsum(xs::AbstractVector; dims::Integer = 1)
dims == 1 || return copy(xs), Δ -> (Δ,)
cumsum(xs), Δ -> (reverse(cumsum(reverse(Δ))),)
end
@adjoint function cumsum(xs::AbstractArray; dims::Integer)
dims <= ndims(xs) || return copy(xs), Δ -> (Δ,)
cumsum(xs; dims=dims), Δ -> begin
(reverse(cumsum(reverse(Δ, dims=dims), dims=dims), dims=dims),)
end
end

@adjoint eachrow(x::AbstractVecOrMat) = collect(eachrow(x)), dys -> ∇eachslice(dys, x, 1)
@adjoint eachcol(x::AbstractVecOrMat) = collect(eachcol(x)), dys -> ∇eachslice(dys, x, 2)
@adjoint eachslice(x::AbstractArray; dims::Integer) =
collect(eachslice(x; dims=dims)), dys -> ∇eachslice(dys, x, dims)

function ∇eachslice(dys, x::AbstractArray, dim::Integer) where {TX}
i1 = findfirst(dy -> dy isa AbstractArray, dys)
i1 === nothing && return (zero(x),) # all slices get nothing
T = promote_type(eltype(dys[i1]), eltype(x))
dx = similar(x, T)
for i in axes(x, dim)
if dys[i] isa AbstractArray
copyto!(selectdim(dx,dim,i), dys[i])
else
selectdim(dx,dim,i) .= 0
end
end
(dx,)
end


# LinearAlgebra
# =============
6 changes: 6 additions & 0 deletions test/gradcheck.jl
@@ -501,6 +501,12 @@ end
@test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))

@test gradient(x -> 1 / maximum(x), [1., 2, 3])[1] == [0, 0, -1/9]

# issue 1224, second order
f1244(w, x) = sum(maximum((w * x).^2, dims=1))
g1244(w, x) = sum(gradient(f1244, w, x)[2].^2)
h1244(w, x) = gradient(g1244, w, x)[2]
@test h1244([1 2 3; 4 5 6.0], [7,8,9.0]) ≈ [300608, 375760, 450912]
end

@testset "minimum" begin