mutating calls #242

Open
maartenvd opened this issue Oct 29, 2020 · 8 comments
Labels
mutability For issues relating to supporting mutability

Comments

@maartenvd

Is it possible to somehow define derivatives of in-place mutating functions? E.g. axpy!(a,x,y) updates y to be y = a*x+y, and therefore its derivative also needs to be updated.

Apologies if this is in the documentation and I missed it.

@oxinabox oxinabox transferred this issue from JuliaDiff/ChainRules.jl Oct 29, 2020
@oxinabox
Member

In theory it's fine (well, there are some rules about things you have to do).
In practice we don't do it because it would break Zygote: Zygote doesn't support mutation,
and if we put in rules that support mutation then Zygote will claim to support mutation and then error.
If we get an AD that does support mutation using ChainRules, then we will need to work out a way for Zygote to opt out of the mutating rules.

@oxinabox oxinabox added the mutability For issues relating to supporting mutability label Oct 29, 2020
@maartenvd
Author

That's great news, though I'm not sure in what sense Zygote doesn't support this. Their Buffer type, for example, mutates in place: https://github.com/FluxML/Zygote.jl/blob/84bf62ea18330389c64d0d918c91d7b897e1a5d8/src/lib/buffer.jl

@oxinabox
Member

The Buffer type is special. It's the only thing in Zygote that supports mutation.

@oxinabox
Member

While I remember.

The two rules of pullbacks for mutating functions

  1. You must undo whatever was changed by this operation in the primal value.
    E.g. a setindex!'s pullback must set that value back to what it was before.
    E.g. a push! must have a pop! in the pullback.
    Worst case scenario is a full overwrite, which means you need to copy all the data before the primal computation, and then write it all back in during the pullback.

  2. The gradient for the mutation of a primal must be applied during the pullback via mutating the gradient for that thing.
    And it must only be applied that way. If it is also returned then that would have it counted twice.

This second rule does seem to pose a problem for mutation support of functions that mutate a value and then don't return it.
In that case we would not have a cotangent passed in to the pullback to mutate.
But such functions are rare, and will (I believe) always be decomposable into a number of more primitive mutating operations that do return the mutated thing.
Still, we can't write custom rules for such things because of this.
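
As a rough illustration of the two rules, here is a minimal hand-written sketch for a single `x[i] = v` assignment. It uses plain Julia only; `setindex_with_pullback!` and `pullback!` are hypothetical names made up for this example and are not part of ChainRulesCore or any AD package.

```julia
# Hypothetical helper (not a ChainRulesCore API): a hand-written "rule"
# for the mutation `x[i] = v` that follows the two rules above.
function setindex_with_pullback!(x::Vector{Float64}, v::Float64, i::Int)
    old = x[i]   # rule 1: save the value that is about to be overwritten
    x[i] = v     # the primal mutation
    function pullback!(dx::Vector{Float64})
        x[i] = old   # rule 1: put the primal back the way it was
        dv = dx[i]   # the cotangent of the overwritten slot belongs to `v`
        dx[i] = 0.0  # rule 2: mutate the gradient of `x` in place; the old
                     # x[i] has no influence on anything computed later
        return dv    # only the gradient for `v` is returned, not `dx` again
    end
    return x, pullback!
end

x  = [1.0, 2.0, 3.0]
dx = [10.0, 20.0, 30.0]  # cotangent of `x` as it is after the mutation
_, pb! = setindex_with_pullback!(x, 5.0, 2)
dv = pb!(dx)
```

After the pullback runs, `x` holds its original values again and `dx` has been updated in place rather than returned, so nothing is counted twice.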

@maartenvd
Author

And in Zygote, what happens when I naively define this pullback for a mutating function that does return the argument it modifies (and the underlying code always uses the return value)?

Currently I just define my rule so that it's not actually mutating in-place ...

@oxinabox
Member

oxinabox commented Dec 4, 2020

If you do that then sometimes Zygote will silently return the wrong answer. I can't tell you offhand which cases those are, though.

@oxinabox
Member

oxinabox commented May 24, 2021

Example of what this looks like (if an AD did support mutation) following the rules posted above

function f(x)
    for i in eachindex(x)
        x[i] = x[i]^2
    end
    return x
end

function rrule(::typeof(f), x)
    # record the signs before `f` overwrites `x`, so the pullback can restore it
    x_is_negative = x .< 0
    function pullback(dy)
        # need to undo the change to `x` in case it is used in another rule.
        x .= sqrt.(x) .* ifelse.(x_is_negative, -1, 1)

        # if mutated on the forward need to mutate to store the derivative
        dy .= 2 .* x .* dy  # d(x^2)/dx = 2x, evaluated at the restored `x`

        # return zero not dy as we have already accumulated that by mutating dy
        return NO_FIELDS, ZeroTangent()
    end
    return f(x), pullback
end
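
To check the idea by hand without any AD package, the following stand-alone sketch may help. The names `square_all!`, `square_all_with_pullback!` and `pullback!` are hypothetical, and the sketch assumes the elementwise derivative of `x^2` is `2x` evaluated at the restored input.

```julia
# Hypothetical stand-alone names; no AD package is involved here.
function square_all!(x)
    for i in eachindex(x)
        x[i] = x[i]^2
    end
    return x
end

function square_all_with_pullback!(x)
    was_negative = x .< 0   # saved before the overwrite, squaring loses the sign
    y = square_all!(x)      # `y` is the very same array as `x`
    function pullback!(dy)
        # rule 1: restore the original input from the squared values
        x .= sqrt.(x) .* ifelse.(was_negative, -1.0, 1.0)
        # rule 2: update the cotangent in place, d(x^2)/dx = 2x
        dy .*= 2 .* x
        return nothing      # nothing extra is returned, `dy` already holds it
    end
    return y, pullback!
end

x = [-3.0, 2.0]
y, pb! = square_all_with_pullback!(x)  # x is now [9.0, 4.0]
dy = [1.0, 1.0]
pb!(dy)                                # restores x and rewrites dy in place
```

With a seed of ones, `dy` ends up holding `2 .* x` for the original `x`, and `x` itself is back to its starting values.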

@vchuravy

So in Enzyme we support mutating calls, aliasing (cf. #350) and activity (cf. #452). All of these problems are somewhat tightly correlated.

For me a motivating example is supporting GPU code where outputs are mutated and there is no return value.
I want to accumulate the gradients in place (since memory pressure on the GPU is a huge issue).

One of the issues @wsmoses and I have been debating is the responsibility of the caller. Enzyme doesn't use closures to capture the inputs; it expects the user to pass in both the shadow and the primal value. So if another part of the program mutates a value, I think we currently expect the user to cache it.

Aliasing within the adjoint is solved by caching, but since we can use LLVM alias analysis we can limit the amount we need to cache.
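
The caller-passes-primal-and-shadow style can be sketched in plain Julia for the `axpy!` case from the top of the thread. This is only an illustration of the calling convention and is not Enzyme's API; `axpy_sketch!` and `axpy_sketch_reverse!` are hypothetical names, and the reverse pass is written by hand.

```julia
# Plain-Julia illustration only; this is NOT Enzyme's API.
# Primal: overwrite `y` with a*x + y and return nothing.
function axpy_sketch!(a, x, y)
    y .+= a .* x
    return nothing
end

# Hand-written reverse pass. The caller owns the shadows `dx` and `dy`:
# on entry `dy` is the cotangent of the mutated `y`; gradients are
# accumulated into the shadows in place instead of being returned as arrays.
function axpy_sketch_reverse!(a, x, dx, dy)
    da = sum(x .* dy)   # gradient w.r.t. the scalar `a`, needs the primal `x`
    dx .+= a .* dy      # accumulate into the caller's shadow of `x`
    # `dy` is left untouched: the old `y` enters the new `y` with weight 1
    return da
end

a = 2.0
x = [1.0, 2.0]
y = [3.0, 4.0]
axpy_sketch!(a, x, y)

dx = zeros(2)
dy = [1.0, 1.0]         # seed: cotangent of the mutated `y`
da = axpy_sketch_reverse!(a, x, dx, dy)
```

Note that the reverse pass reads the primal `x`. If some other part of the program overwrote `x` between the forward and reverse passes, the caller would have to keep a cached copy, which is the responsibility question raised above.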
