Description
Consider this function:
```julia
sq(x) = x==1 ? one(x) : x^2

@test FiniteDifferences.central_fdm(5, 1)(sq, 1) ≈ 2.0
@test_broken ForwardDiff.derivative(sq, 1.0) == 2.0
```
Here `ForwardDiff` gets the wrong answer, according to your first calculus class: the derivative is defined by taking limits, evaluating `sq(x + ε)` for some small `ε`, and these evaluations always see the continuum `x^2`, not the special point.
One way to think about this is to say that `x==1` really means `abs(x-1) < ζ` for some tiny `ζ`, which we keep finite until we are sure we aren't confused. The calculus-class assumption is that `ζ << ε`.
The assumption of `ForwardDiff` is the opposite. Its `Dual(x,1)` encodes a perturbation `x + 1ε` with `ε` smaller than everything else around, and in particular `ε << ζ`. Or in other words, `sq` is viewed as being piecewise continuous, with a small flat area of width `2ζ`, which is still large enough for us to see that its slope is zero.
Of course nobody really writes contrived examples like `sq`. But they do write things like this:
```julia
function prod1(xs::Vector)
    p = one(eltype(xs))
    for x in xs
        p = p * x
    end
    p
end

function prod2(xs::Vector)
    p = one(eltype(xs))
    for x in xs
        p = p * x
        p == 0 && break  # exit early once you know the answer
    end
    p
end

@test ForwardDiff.gradient(prod1, [1,2,0,4,0,6]) == zeros(6)
@test_broken ForwardDiff.gradient(prod2, [1,2,0,4,0,6]) == zeros(6)
```
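To see concretely how the early exit goes wrong, here is a minimal hand-rolled forward mode. This is a sketch only: `VecDual` and `toy_gradient_prod2` are illustrative names, not ForwardDiff's actual implementation, but `==` compares values alone, as ForwardDiff currently does.

```julia
# Toy dual number carrying a vector of partials, one slot per input.
# Illustrative only -- not ForwardDiff's Dual type.
struct VecDual
    value::Float64
    partials::Vector{Float64}
end

# Product rule: d(ab) = a*db + b*da
Base.:*(a::VecDual, b::VecDual) =
    VecDual(a.value * b.value, a.value .* b.partials .+ b.value .* a.partials)

# Value-only comparison, mirroring ForwardDiff's current behaviour.
Base.:(==)(a::VecDual, b::Int) = a.value == b

function toy_gradient_prod2(xs::Vector{Float64})
    n = length(xs)
    # Seed each input with a unit partial in its own slot.
    duals = [VecDual(xs[i], [i == j ? 1.0 : 0.0 for j in 1:n]) for i in 1:n]
    p = VecDual(1.0, zeros(n))
    for x in duals
        p = p * x
        p == 0 && break  # the early exit from prod2
    end
    p.partials
end

toy_gradient_prod2([1.0, 2.0, 0.0, 4.0, 0.0, 6.0])  # [0, 0, 2, 0, 0, 0], not zeros(6)
```

The accidental zero at index 3 stops the loop while the partial with respect to that entry is still `2`; the true gradient is `zeros(6)`, since the input contains two zeros.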
This has almost the same problem as #197, where `det(A)` tests for `istriu(A) || istril(A)` before calling a simpler branch. The fact that `f(x,y) == g(x,y)` when `y==0` does not imply that `df/dy == dg/dy`. So it seems AD ought not to take that branch.
In which case we want something like this:

```julia
Base.:(==)(x::Dual, y::Int) = x.value == y && iszero(x.partials)
Base.:(!=)(x::Dual, y::Int) = x.value != y || !iszero(x.partials)
```
This fixes the tests above, and (a slightly more careful version) fixes #197 and #407.
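The effect of the proposed rule can be seen with a minimal toy dual type. This is a sketch assuming only the value/partial mechanics of forward mode: `TinyDual`, `value_eq`, and the `sq_*` functions are hypothetical names, not ForwardDiff's API.

```julia
# Toy dual number; illustrative only, not ForwardDiff's Dual type.
struct TinyDual
    value::Float64
    partial::Float64
end

Base.:*(a::TinyDual, b::TinyDual) =
    TinyDual(a.value * b.value, a.value * b.partial + a.partial * b.value)
Base.one(::TinyDual) = TinyDual(1.0, 0.0)

value_eq(x::TinyDual, y::Int)    = x.value == y                        # current rule
partials_eq(x::TinyDual, y::Int) = x.value == y && iszero(x.partial)   # proposed rule

sq_value(x)    = value_eq(x, 1)    ? one(x) : x * x
sq_partials(x) = partials_eq(x, 1) ? one(x) : x * x

sq_value(TinyDual(1.0, 1.0)).partial     # 0.0 -- the == branch hides the slope
sq_partials(TinyDual(1.0, 1.0)).partial  # 2.0 -- agrees with the limit definition
```

With a nonzero partial, the perturbed point is no longer "equal" to `1`, so the continuum branch is taken and the derivative matches the limit.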
However, it means that `fun(Dual(x,1)).value` need not be equal to `fun(x)`, on a discontinuous function. Although `fun(Dual(x,0)).value` should still be equal, `@assert zero(x) == 0` isn't broken, and there should be no problems where functions use things like `zero(eltype(xs))` for type-stability.
The idea that the forward evaluation is unchanged is often thought of as an axiom of AD, but for discontinuous functions, I think that's another way of saying `ε << ζ`. Which is a choice. And one that your calculus teacher would disapprove of. The point of evaluating a function with dual numbers is, presumably, to find derivatives, so finding them correctly ought to have a higher priority.
There are other comparisons to think about, for example:

```julia
sq2(x) = x>1 ? x^2 : x<1 ? x^2 : one(x)
clamp2(x, lo=0, hi=1) = x>hi ? oftype(x,hi) : x<lo ? oftype(x,lo) : x
clamp3(x, lo=0, hi=1) = x>=hi ? oftype(x,hi) : x<=lo ? oftype(x,lo) : x

[ForwardDiff.derivative(cl, 1.0) for cl in [x->clamp(x,0,1), clamp2, clamp3]] == [1,1,0]
[central_fdm(5, 1)(cl, 1.0) for cl in [x->clamp(x,0,1), clamp2, clamp3]] ≈ [0.5, 0.5, 0.5]
```
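The `[1, 1, 0]` pattern comes from the value-only `>` and `>=` deciding which branch the dual takes at the boundary. Here is a toy reproduction (a sketch; `MiniDual` is a hypothetical type, not ForwardDiff's `Dual`):

```julia
# Toy dual number to show why clamp2 and clamp3 disagree at x = 1.
# Illustrative only -- not ForwardDiff's Dual type.
struct MiniDual
    value::Float64
    partial::Float64
end

# Value-only comparisons, like ForwardDiff's current rules.
Base.:>(a::MiniDual, b::Real)  = a.value > b
Base.:<(a::MiniDual, b::Real)  = a.value < b
Base.:>=(a::MiniDual, b::Real) = a.value >= b
Base.:<=(a::MiniDual, b::Real) = a.value <= b

# Needed so oftype(x, hi) produces a constant dual with zero partial.
Base.convert(::Type{MiniDual}, y::Real) = MiniDual(float(y), 0.0)

clamp2(x, lo=0, hi=1) = x > hi ? oftype(x, hi) : x < lo ? oftype(x, lo) : x
clamp3(x, lo=0, hi=1) = x >= hi ? oftype(x, hi) : x <= lo ? oftype(x, lo) : x

clamp2(MiniDual(1.0, 1.0)).partial  # 1.0: strict > sends x = 1 to the identity branch
clamp3(MiniDual(1.0, 1.0)).partial  # 0.0: >= sends x = 1 to the constant branch
```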
I'm not sure how often simulating `x==1` as in `sq2(x)` happens in the wild. Perhaps from some combination like `f(x) = relu(x) + 0.1*relu(-x)`?

But `clamp`ing parameters to some range is routine. Your calculus teacher would throw an error here, but that's probably not the most helpful response for the computer.
Returning a nonzero derivative here is useful because, if this is some parameter being optimised, it means gradient descent won't get stuck against the wall when the gradient points away from it. So you can argue that the ability to choose which sub-gradient `ForwardDiff` will use is a feature. The `0.5` gradient à la `FiniteDifferences` would also be fine for getting un-stuck, but it's very difficult to picture how `ForwardDiff` could evaluate both branches, and easy to picture doing so having awful side-effects.
Here is one way to relate the present rules for `>(::Dual, ::Real)` and `>=(::Dual, ::Real)` to the finite-everything `ζ << ε` story. We can say that while the `ε`-ball overlaps with both sides, the vote from the longer side (longer by about `2ζ`) always wins by a hair:

```
----------(==========1==========)---------- abs(x-1) < ε
---------------------1-(=================== x > 1+ζ
+++++++++++++.......... gradient votes, clamp2(1.0)
```
Trying out the above `==(::Dual, ::Real)` rule, it looks like the tests of this package all pass, except for the ones explicitly testing such rules. It would be interesting to know if this breaks any other uses in the wild. It would also be interesting to think up other pathological examples; maybe I've missed something important.
Also:

- Another way to talk about this: the problem with `prod2` above, and `det` in support for the determinant function (#197), is that they promote accidental zeros of the input to structural zeros. AD then respects these, and gets the wrong answer. What looked like a simple optimisation when writing for real numbers has been unintentionally promoted to a constraint on what derivatives are allowed. This is the reverse of the discussion about preserving structural zeros in things like `Zygote.gradient(sum∘exp, Diagonal(ones(3)))[1]`.

- Some other packages get this right, such as TaylorSeries.jl, this C++ code, and this Ruby. Some get it wrong (according to me), like this Mathematica code, this paper with Matlab, and this blog post, although he changed his mind from right to wrong. More mathematical treatments seem to regard the tuple `(x,δx)` as inheriting `==` from tuples, i.e. they get it right.

- Similar things were also discussed in Problem with `abs` (#377), where the example is this:

  ```julia
  sq3(z) = abs(-z^2)
  sq4(z) = abs2(z)
  sq5(z) = z^2

  [ForwardDiff.derivative(f, 0.0) for f in [sq, sq2, sq3, sq4, sq5]] == [0,0,0,0,0]
  [ForwardDiff.hessian(x -> f(x[1]), [0.0])[1] for f in [sq, sq2, sq3, sq4, sq5]] == [2,2,-2,2,2]
  ```

  A rule was suggested there in which `x > y` behaves differently when `x.value == y.value`, breaking such ties by comparing `x.partials > y.partials`. In the `clamp2` example, whether you get stuck against the wall presumably shouldn't depend on whether you minimise `loss(x)` or maximise `-loss(x)`, so we probably don't want to compare `x.partials .> 0` when only `x` is a dual number. But the rule when both `x` and `y` are dual might be worth some more thought.