
Wrong results with higher order pullback -- chained if/else in accum #937

Open
axsk opened this issue Mar 30, 2021 · 9 comments
Labels
bug Something isn't working

Comments

@axsk

axsk commented Mar 30, 2021

Here comes a strange one. This is the most minimal working example I could find that produces the erroneous result.
The loss function computes the gradients of some function (think of it as a mixture of Euclidean distances) at each point given by the columns of x/data. I extract the gradients with the dfdx, ... statement and return them as the result of the loss function.

At a later stage I want to optimize some model over this loss function, so I need the derivative of this spatial gradient with respect to the model parameters. Here the model is just the identity and the loss does not depend on the parameter z.

However, when the loss is evaluated via (inside) the pullback function, it returns a different result.

Strangely, this only happens when the sums involve the abs2 terms and both aa and bb are subtracted; otherwise the results l1 and l2 are the same 😮

```julia
using Zygote

function mwe()
    a = ones(1) * 0
    b = ones(1) * .5
    data = ones(1)

    function loss()
        y, pb = Zygote.pullback(data) do x
            aa = sum(abs2, x .- a, dims=1)
            bb = sum(abs2, x .- b, dims=1)
            r = 1 .- aa .- bb
        end
        dfdx,  = pb(data)
        dfdx
    end

    l1 = sum(loss())

    l2, pb = Zygote.pullback() do
        sum(loss())
    end

    @show l1, l2
    @assert l1 == l2
end
```

```julia
julia> mwe()
(l1, l2) = (-3.0, -2.0)
ERROR: AssertionError: l1 == l2
```

Any feedback is welcome :)

@axsk axsk changed the title Inconsistent results with pullback Inconsistent results with higher order pullback Apr 9, 2021
@axsk
Author

axsk commented Apr 9, 2021

loss() = d/dx (1 - x^2 - (x - 1/2)^2) = -4x + 1, and since x = data = 1, we have loss() = -3. The -2 returned by the outer pullback is wrong.
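As an AD-independent cross-check of the expected -3, here is a minimal sketch that evaluates the hand-derived derivative and a central finite difference of the scalar version of the inner function (a = 0, b = 0.5, as in the report). The names `r`, `dr` and `central_diff` are illustrative helpers, not part of Zygote. Because the inner pullback is seeded with data = 1, dfdx should equal r'(1).

```julia
# Scalar version of the inner function from the report, with a = 0 and b = 0.5
r(x) = 1 - abs2(x - 0.0) - abs2(x - 0.5)

# Hand-derived derivative: -2x - 2(x - 0.5) = -4x + 1
dr(x) = -4x + 1

# Central finite difference (hypothetical helper); exact up to rounding for a quadratic
central_diff(f, x; h = 1e-4) = (f(x + h) - f(x - h)) / (2h)

x = 1.0
println(dr(x))               # -3.0
println(central_diff(r, x))  # approximately -3.0
```

Both agree with l1 = -3.0 from the direct call, which supports the view that l2 = -2.0 from the outer pullback is the wrong value.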

@axsk axsk changed the title Inconsistent results with higher order pullback Wrong results with higher order pullback Apr 9, 2021
@DhairyaLGandhi
Member

Huh, interesting. This is a bit odd.

This might be caused by an incorrect abs2 adjoint definition, erroneous forward-pass rewriting, or a global lookup.

cc @willtebbutt @mzgubic @oxinabox

@mzgubic
Collaborator

mzgubic commented Apr 9, 2021

I tried commenting out the adjoints for abs2, but that didn't solve it. It also breaks without arrays:

```julia
julia> function mwe()
           a = 0
           b = .5
           data = 1.0

           function loss()
               y, pb = Zygote.pullback(data) do x
                   aa = abs2(x - a)
                   bb = abs2(x - b)
                   r = 1 - aa - bb
               end
               dfdx,  = pb(data)
               dfdx
           end

           l1 = sum(loss())

           l2, pb = Zygote.pullback() do
               sum(loss())
           end

           @show l1, l2
           @assert l1 == l2
       end
mwe (generic function with 1 method)

julia> mwe()
(l1, l2) = (-3.0, -2.0)
ERROR: AssertionError: l1 == l2
Stacktrace:
 [1] mwe()
   @ Main ./REPL[9]:23
 [2] top-level scope
   @ REPL[10]:1
```

@DhairyaLGandhi
Member

I tried running a few experiments with explicit parameters to see whether global lookups were the cause, but that didn't fix it either.

@axsk
Author

axsk commented Apr 12, 2021

Might this be related to perturbation confusion (cf. JuliaDiff/ForwardDiff.jl#83)?

@axsk
Author

axsk commented Apr 12, 2021

> Might this be related to perturbation confusion (cf. JuliaDiff/ForwardDiff.jl#83)?

The example from the referenced paper (which describes the problem for forward mode), `gradient(x -> x * gradient(y -> x + y, 1)[1], 1) == 1`, works fine, so it's probably not that problem.
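For reference, the expected value of that nested derivative can be reproduced without any automatic differentiation, so there are no perturbation tags to confuse. This is a minimal sketch using nested central finite differences; `cdiff`, `inner` and `outer` are hypothetical names chosen for illustration.

```julia
# Central finite difference (hypothetical helper); exact up to rounding for linear functions
cdiff(f, x; h = 1e-3) = (f(x + h) - f(x - h)) / (2h)

# Inner derivative of y -> x + y at y = 1; equals 1 for every x
inner(x) = cdiff(y -> x + y, 1.0)

# Outer derivative of x -> x * inner(x) at x = 1; the expected answer is 1
outer = cdiff(x -> x * inner(x), 1.0)
println(outer)  # close to 1.0
```

A correct nested AD result should match this value of 1, which is what the comment above reports Zygote returning for this particular test.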

@axsk
Author

axsk commented Apr 13, 2021

The outer gradients are computed incorrectly too (though not in this example, since it has no outer gradient).

@DhairyaLGandhi
Member

I'm going to have to dig deeper into this; will investigate.

@mcabbott
Member

Following up on @mzgubic's simplification, here is a more minimal example, one without a second derivative at all. The problem appears to be that Zygote gets confused by the if statements in accum and silently gives wrong answers, which is slightly disturbing.

```julia
julia> using Zygote

julia> function mmwe()
           f() = gradient(x -> 13*x + x, 17)[1]
           α = f()
           β, _ = pullback(f)
           @show α, β
           nothing
       end;

julia> mmwe()
(α, β) = (14, 2)

julia> let
          α = Zygote.accum(1,2)
          β, _ = pullback(Zygote.accum,1,2)
          @show α, β
       end;
(α, β) = (3, 2)

julia> Zygote.accum(x, y) =
         x === nothing ? y :
         # y === nothing ? x :   # this won't fix mwe(), it needs accum(::Float64, ::Missing)
         x + y

julia> mmwe()
(α, β) = (14, 14)
```
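The primal value returned by `pullback(f, args...)` should always equal `f(args...)`. As a minimal sketch of the values the pullback's forward pass ought to reproduce, here is a hypothetical `toy_accum` with the same chained-ternary shape as the two-branch version shown above plus the commented-out `y === nothing` branch. It is an illustration only, not Zygote's actual source. The last line assumes the two uses of `x` in `x -> 13*x + x` contribute 13 and 1 and are summed by accum.

```julia
# Hypothetical stand-in mirroring the chained ternary shape of accum(x, y);
# not Zygote's definition, only an illustration of the expected primal values.
toy_accum(x, y) =
    x === nothing ? y :
    y === nothing ? x :
    x + y

println(toy_accum(1, 2))        # 3 (the buggy pullback primal above gives 2)
println(toy_accum(nothing, 2))  # 2
println(toy_accum(1, nothing))  # 1

# Assumed contributions of the two uses of x in x -> 13*x + x
println(toy_accum(13, 1))       # 14, the value both α and β in mmwe() should show
```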

@mcabbott mcabbott changed the title Wrong results with higher order pullback Wrong results with higher order pullback -- chained if/else in accum Sep 9, 2021
@mcabbott mcabbott added the bug Something isn't working label Jul 22, 2022
Projects
None yet
Development

No branches or pull requests

4 participants