Error with control flow #1236

Open
mtfishman opened this issue Jun 3, 2022 · 4 comments
Labels
bug (Something isn't working), help wanted (Extra attention is needed)

Comments

@mtfishman

We've come across a strange bug involving control flow:

julia> using FiniteDifferences

julia> using Zygote

julia> function f(x)
         y = [[x]', [x]]
         r = 0.0
         o = 1.0
         for n in 1:2
           o *= y[n]
           if n < 2
             proj_o = o * [1.0]
           else
             # Error
             proj_o = o
             # Fix
             # proj_o = o * 1.0
           end
           r += proj_o
         end
         return r
       end
f (generic function with 1 method)

julia> x = 1.2
1.2

julia> f(x)
2.6399999999999997

julia> central_fdm(5, 1)(f, x)
3.4000000000000967

julia> f'(x)
ERROR: MethodError: no method matching +(::Float64, ::LinearAlgebra.Adjoint{Float64, Vector{Float64}})
For element-wise addition, use broadcasting with dot syntax: scalar .+ array
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at ~/software/julia-1.7.3/share/julia/base/operators.jl:655
  +(::Union{Float16, Float32, Float64}, ::BigFloat) at ~/software/julia-1.7.3/share/julia/base/mpfr.jl:413
  +(::ChainRulesCore.Tangent{P}, ::P) where P at ~/.julia/packages/ChainRulesCore/GUvJT/src/tangent_arithmetic.jl:146
  ...
Stacktrace:
 [1] accum(x::Float64, y::LinearAlgebra.Adjoint{Float64, Vector{Float64}})
   @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:17
 [2] Pullback
   @ ./REPL[16]:15 [inlined]
 [3] (::typeof(∂(f)))(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [4] (::Zygote.var"#52#53"{typeof(∂(f))})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:41
 [5] (::Zygote.var"#54#55"{typeof(f)})(x::Float64)
   @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:83
 [6] top-level scope
   @ REPL[20]:1

Changing the line:

proj_o = o

to:

proj_o = o * 1.0

fixes the issue and outputs:

julia> f(x)
2.6399999999999997

julia> central_fdm(5, 1)(f, x)
3.4000000000000967

julia> f'(x)
3.4000000000000004
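
As a cross-check on the value 3.4, the loop can be unrolled by hand. The sketch below uses only LinearAlgebra (no Zygote) and assumes the same x = 1.2; the names o1, p1, o2, p2 and dfdx are illustrative and not part of the original report. It suggests that f reduces to x + x^2, so the derivative should be 1 + 2x, in line with both the finite-difference result and the fixed f'(x).

```julia
using LinearAlgebra

x = 1.2
y = [[x]', [x]]      # y[1] is a row (adjoint of a vector), y[2] is a column vector

# Iteration n = 1
o1 = 1.0 * y[1]      # scalar * row: still a row
p1 = o1 * [1.0]      # row * column: a plain Float64, equal to x

# Iteration n = 2
o2 = o1 * y[2]       # row * column: a plain Float64, equal to x^2
p2 = o2              # the `proj_o = o` line that triggers the error under Zygote

r = 0.0 + p1 + p2    # f(x) = x + x^2
dfdx = 1 + 2x        # hand-derived derivative of x + x^2
```

Note that o changes type between iterations (a row after n = 1, a scalar after n = 2), so the gradients flowing back through o also differ in type from one iteration to the next.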

Version information:

julia> versioninfo()
Julia Version 1.7.3
Commit 742b9abb4d (2022-05-06 12:58 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Xeon(R) E-2176M  CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, skylake)
Environment:
  JULIA_EDITOR = vim

julia> using Pkg

julia> Pkg.status("Zygote")
      Status `~/.julia/environments/v1.7/Project.toml`
  [e88e6eb3] Zygote v0.6.40

julia> Pkg.status("ChainRules")
      Status `~/.julia/environments/v1.7/Project.toml`
  [082447d4] ChainRules v1.35.1

Original issue is here: ITensor/ITensorMPS.jl#73

ToucheSir added the bug and help wanted labels on Jun 5, 2022
@ToucheSir
Member

ToucheSir commented Jun 5, 2022

IR dump for debugging:

julia> adj = @code_adjoint f(1.2);

julia> adj.primal
1: (%3, %4 :: Zygote.Context, %1, %2)
  %5 = Zygote._pullback(%4, Base.vect, %2)
  %6 = Base.getindex(%5, 1)
  %7 = Base.getindex(%5, 2)
  %8 = Zygote._pullback(%4, Main.:var"'", %6)
  %9 = Base.getindex(%8, 1)
  %10 = Base.getindex(%8, 2)
  %11 = Zygote._pullback(%4, Base.vect, %2)
  %12 = Base.getindex(%11, 1)
  %13 = Base.getindex(%11, 2)
  %14 = Zygote._pullback(%4, Base.vect, %9, %12)
  %15 = Base.getindex(%14, 1)
  %16 = Base.getindex(%14, 2)
  %17 = Zygote._pullback(%4, Main.:(:), 1, 2)
  %18 = Base.getindex(%17, 1)
  %19 = Base.getindex(%17, 2)
  %20 = Zygote._pullback(%4, Base.iterate, %18)
  %21 = Base.getindex(%20, 1)
  %22 = Base.getindex(%20, 2)
  %23 = %21 === nothing
  %24 = Base.not_int(%23)
  br 6 (0.0, 1) unless %24
  br 2 (%21, 1.0, 0.0, 1)
2: (%25, %26, %27, %59 :: UInt8)
  %28 = Zygote._pullback(%4, Zygote.literal_getfield, %25, Val{1}())
  %29 = Base.getindex(%28, 1)
  %30 = Base.getindex(%28, 2)
  %31 = Zygote._pullback(%4, Zygote.literal_getfield, %25, Val{2}())
  %32 = Base.getindex(%31, 1)
  %33 = Base.getindex(%31, 2)
  %34 = Zygote._pullback(%4, Base.getindex, %15, %29)
  %35 = Base.getindex(%34, 1)
  %36 = Base.getindex(%34, 2)
  %37 = Zygote._pullback(%4, Main.:*, %26, %35)
  %38 = Base.getindex(%37, 1)
  %39 = Base.getindex(%37, 2)
  %40 = Zygote._pullback(%4, Main.:<, %29, 2)
  %41 = Base.getindex(%40, 1)
  %42 = Base.getindex(%40, 2)
  br 4 unless %41
  br 3
3:
  %43 = Zygote._pullback(%4, Base.vect, 1.0)
  %44 = Base.getindex(%43, 1)
  %45 = Base.getindex(%43, 2)
  %46 = Zygote._pullback(%4, Main.:*, %38, %44)
  %47 = Base.getindex(%46, 1)
  %48 = Base.getindex(%46, 2)
  br 5 (%47, 1)
4:
  br 5 (%38, 2)
5: (%49, %60 :: UInt8)
  %50 = Zygote._pullback(%4, Main.:+, %27, %49)
  %51 = Base.getindex(%50, 1)
  %52 = Base.getindex(%50, 2)
  %53 = Zygote._pullback(%4, Base.iterate, %18, %32)
  %54 = Base.getindex(%53, 1)
  %55 = Base.getindex(%53, 2)
  %56 = %54 === nothing
  %57 = Base.not_int(%56)
  br 6 (%51, 2) unless %57
  br 2 (%54, %38, %51, 2)
6: (%58, %61 :: UInt8)
  return %58

julia> adj.adjoint
1: (%1)
  %2 = @61 !== 0x01
  br 6 (nothing, nothing, nothing) unless %2
  br 2 (%1, nothing, nothing, nothing, nothing)
2: (%3, %4, %5, %6, %7)
  %8 = @60 !== 0x01
  %9 = (@55)(%4)
  %10 = Zygote.gradindex(%9, 2)
  %11 = Zygote.gradindex(%9, 3)
  %12 = (@52)(%3)
  %13 = Zygote.gradindex(%12, 2)
  %14 = Zygote.gradindex(%12, 3)
  %15 = Zygote.accum(%6, %10)
  %16 = Zygote.accum(%14, %5)
  br 4 unless %8
  br 3
3:
  br 5 (%16)
4:
  %17 = (@48)(%14)
  %18 = Zygote.gradindex(%17, 2)
  %19 = Zygote.gradindex(%17, 3)
  %20 = (@45)(%19)
  %21 = Zygote.accum(%5, %18)
  br 5 (%21)
5: (%22)
  %23 = @59 !== 0x01
  %24 = (@42)(nothing)
  %25 = Zygote.gradindex(%24, 2)
  %26 = (@39)(%22)
  %27 = Zygote.gradindex(%26, 2)
  %28 = Zygote.gradindex(%26, 3)
  %29 = (@36)(%28)
  %30 = Zygote.gradindex(%29, 2)
  %31 = Zygote.gradindex(%29, 3)
  %32 = (@33)(%11)
  %33 = Zygote.gradindex(%32, 2)
  %34 = Zygote.accum(%25, %31)
  %35 = (@30)(%34)
  %36 = Zygote.gradindex(%35, 2)
  %37 = Zygote.accum(%33, %36)
  %38 = Zygote.accum(%7, %30)
  br 6 (%37, %15, %38) unless %23
  br 2 (%13, %37, %27, %15, %38)
6: (%39, %40, %41)
  %42 = (@22)(%39)
  %43 = Zygote.gradindex(%42, 2)
  %44 = Zygote.accum(%40, %43)
  %45 = (@19)(%44)
  %46 = (@16)(%41)
  %47 = Zygote.gradindex(%46, 2)
  %48 = Zygote.gradindex(%46, 3)
  %49 = (@13)(%48)
  %50 = Zygote.gradindex(%49, 2)
  %51 = (@10)(%47)
  %52 = Zygote.gradindex(%51, 2)
  %53 = (@7)(%52)
  %54 = Zygote.gradindex(%53, 2)
  %55 = Zygote.accum(%50, %54)
  %56 = Zygote.tuple(nothing, %55)
  return %56

With pullbacks filled in:

1: (Δ)
  %2 = @61 !== 0x01
  br 6 (nothing, nothing, nothing) unless %2
  br 2 (Δ, nothing, nothing, nothing, nothing)
2: (%3, %4, %5, %6, %7)
  %8 = @60 !== 0x01
  %9 = (Base.iterate)(%4)
  %10 = Zygote.gradindex(%9, 2)
  %11 = Zygote.gradindex(%9, 3)
  %12 = (Main.:+)(%3)
  %13 = Zygote.gradindex(%12, 2)
  %14 = Zygote.gradindex(%12, 3)
  %15 = Zygote.accum(%6, %10)
  %16 = Zygote.accum(%14, %5)
  br 4 unless %8
  br 3
3:
  br 5 (%16)
4:
  %17 = (Main.:*)(%14)
  %18 = Zygote.gradindex(%17, 2)
  %19 = Zygote.gradindex(%17, 3)
  %20 = (Base.vect)(%19)
  %21 = Zygote.accum(%5, %18)
  br 5 (%21)
5: (%22)
  %23 = @59 !== 0x01
  %24 = (Main.:<)(nothing)
  %25 = Zygote.gradindex(%24, 2)
  %26 = (Main.:*)(%22)
  %27 = Zygote.gradindex(%26, 2)
  %28 = Zygote.gradindex(%26, 3)
  %29 = (Base.getindex)(%28)
  %30 = Zygote.gradindex(%29, 2)
  %31 = Zygote.gradindex(%29, 3)
  %32 = (Zygote.literal_getfield)(%11)
  %33 = Zygote.gradindex(%32, 2)
  %34 = Zygote.accum(%25, %31)
  %35 = (Zygote.literal_getfield)(%34)
  %36 = Zygote.gradindex(%35, 2)
  %37 = Zygote.accum(%33, %36)
  %38 = Zygote.accum(%7, %30)
  br 6 (%37, %15, %38) unless %23
  br 2 (%13, %37, %27, %15, %38)
6: (%39, %40, %41)
  %42 = (Base.iterate)(%39)
  %43 = Zygote.gradindex(%42, 2)
  %44 = Zygote.accum(%40, %43)
  %45 = (Main.:(:))(%44)
  %46 = (Base.vect)(%41)
  %47 = Zygote.gradindex(%46, 2)
  %48 = Zygote.gradindex(%46, 3)
  %49 = (Base.vect)(%48)
  %50 = Zygote.gradindex(%49, 2)
  %51 = (Main.:var"'")(%47)
  %52 = Zygote.gradindex(%51, 2)
  %53 = (Base.vect)(%52)
  %54 = Zygote.gradindex(%53, 2)
  %55 = Zygote.accum(%50, %54)
  %56 = Zygote.tuple(nothing, %55)
  return %56

@ToucheSir
Member

ToucheSir commented Jun 5, 2022

Another interesting bit: adding any operation around o in the problematic line resolves the issue:

julia> function f(x)
         y = [[x]', [x]]
         r = 0.0
         o = 1.0
         for n in 1:2
           o *= y[n]
           if n < 2
             proj_o = o * [1.0]
           else
             proj_o = identity(o) # @showgrad also works
           end
           r += proj_o
         end
         return r
       end
f (generic function with 1 method)

julia> gradient(f, 1.2)
(3.4000000000000004,)

And the pullback IR:

1: (%1)
  %2 = @64 !== 0x01
  br 6 (nothing, nothing, nothing) unless %2
  br 2 (%1, nothing, nothing, nothing, nothing)
2: (%3, %4, %5, %6, %7)
  %8 = @63 !== 0x01
  %9 = (@58)(%4)
  %10 = Zygote.gradindex(%9, 2)
  %11 = Zygote.gradindex(%9, 3)
  %12 = (@55)(%3)
  %13 = Zygote.gradindex(%12, 2)
  %14 = Zygote.gradindex(%12, 3)
  %15 = Zygote.accum(%6, %10)
  br 4 unless %8
  br 3
3:
  %16 = (@51)(%14)
  %17 = Zygote.gradindex(%16, 2)
  %18 = Zygote.accum(%5, %17)
  br 5 (%18)
4:
  %19 = (@48)(%14)
  %20 = Zygote.gradindex(%19, 2)
  %21 = Zygote.gradindex(%19, 3)
  %22 = (@45)(%21)
  %23 = Zygote.accum(%5, %20)
  br 5 (%23)
5: (%24)
  %25 = @62 !== 0x01
  %26 = (@42)(nothing)
  %27 = Zygote.gradindex(%26, 2)
  %28 = (@39)(%24)
  %29 = Zygote.gradindex(%28, 2)
  %30 = Zygote.gradindex(%28, 3)
  %31 = (@36)(%30)
  %32 = Zygote.gradindex(%31, 2)
  %33 = Zygote.gradindex(%31, 3)
  %34 = (@33)(%11)
  %35 = Zygote.gradindex(%34, 2)
  %36 = Zygote.accum(%27, %33)
  %37 = (@30)(%36)
  %38 = Zygote.gradindex(%37, 2)
  %39 = Zygote.accum(%35, %38)
  %40 = Zygote.accum(%7, %32)
  br 6 (%39, %15, %40) unless %25
  br 2 (%13, %39, %29, %15, %40)
6: (%41, %42, %43)
  %44 = (@22)(%41)
  %45 = Zygote.gradindex(%44, 2)
  %46 = Zygote.accum(%42, %45)
  %47 = (@19)(%46)
  %48 = (@16)(%43)
  %49 = Zygote.gradindex(%48, 2)
  %50 = Zygote.gradindex(%48, 3)
  %51 = (@13)(%50)
  %52 = Zygote.gradindex(%51, 2)
  %53 = (@10)(%49)
  %54 = Zygote.gradindex(%53, 2)
  %55 = (@7)(%54)
  %56 = Zygote.gradindex(%55, 2)
  %57 = Zygote.accum(%52, %56)
  %58 = Zygote.tuple(nothing, %57)
  return %58

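At first glance, the biggest difference is that an accum has moved from block 2 to block 3 (nominally the else branch of the if n < 2). In the failing IR, accum(%14, %5) sits in block 2, before the branch, so it runs on both paths through the conditional; in the working IR the corresponding accum only runs inside block 3. I was hoping this wouldn't be a compiler issue, but it is increasingly looking like it might be one.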
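
To illustrate why the placement of an accum matters, here is a minimal sketch of gradient accumulation that needs only LinearAlgebra. accum_sketch is a hypothetical stand-in written for this example, not Zygote's source; it assumes that `nothing` means "no gradient" and that everything else is summed with `+`, which is consistent with the `+(::Float64, ::Adjoint)` call in the stack trace above.

```julia
using LinearAlgebra

# Hypothetical stand-in for gradient accumulation (an assumption, not library code):
# `nothing` is treated as "no gradient", anything else is added with `+`.
accum_sketch(a, b) = a === nothing ? b : b === nothing ? a : a + b

scalar_grad = 1.0    # gradient belonging to a Float64 value
row_grad = [1.2]'    # gradient belonging to a row (adjoint of a vector)

accum_sketch(nothing, row_grad)    # fine: the row passes through unchanged
accum_sketch(scalar_grad, 2.0)     # fine: two scalars are added

# Mixing a scalar gradient with a row gradient has no `+` method,
# which is the same kind of MethodError shown in the report.
mismatch_fails = try
    accum_sketch(scalar_grad, row_grad)
    false
catch err
    err isa MethodError
end
```

Under these assumptions, an accumulation that is only type-correct on one branch will throw as soon as it is evaluated on the other.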

@mtfishman
Author

Thanks for looking into it! Interesting to see that adding any operation to that line circumvents the bug.

@mcabbott
Member

mcabbott commented Jun 6, 2022

This looks similar to #937 (comment), which seems to have fixed itself somehow (on Zygote v0.6.40).
