-
-
Notifications
You must be signed in to change notification settings - Fork 213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Error with control flow #1236
Comments
IR dump for debugging:
julia> adj = @code_adjoint f(1.2);
julia> adj.primal
1: (%3, %4 :: Zygote.Context, %1, %2)
%5 = Zygote._pullback(%4, Base.vect, %2)
%6 = Base.getindex(%5, 1)
%7 = Base.getindex(%5, 2)
%8 = Zygote._pullback(%4, Main.:var"'", %6)
%9 = Base.getindex(%8, 1)
%10 = Base.getindex(%8, 2)
%11 = Zygote._pullback(%4, Base.vect, %2)
%12 = Base.getindex(%11, 1)
%13 = Base.getindex(%11, 2)
%14 = Zygote._pullback(%4, Base.vect, %9, %12)
%15 = Base.getindex(%14, 1)
%16 = Base.getindex(%14, 2)
%17 = Zygote._pullback(%4, Main.:(:), 1, 2)
%18 = Base.getindex(%17, 1)
%19 = Base.getindex(%17, 2)
%20 = Zygote._pullback(%4, Base.iterate, %18)
%21 = Base.getindex(%20, 1)
%22 = Base.getindex(%20, 2)
%23 = %21 === nothing
%24 = Base.not_int(%23)
br 6 (0.0, 1) unless %24
br 2 (%21, 1.0, 0.0, 1)
2: (%25, %26, %27, %59 :: UInt8)
%28 = Zygote._pullback(%4, Zygote.literal_getfield, %25, Val{1}())
%29 = Base.getindex(%28, 1)
%30 = Base.getindex(%28, 2)
%31 = Zygote._pullback(%4, Zygote.literal_getfield, %25, Val{2}())
%32 = Base.getindex(%31, 1)
%33 = Base.getindex(%31, 2)
%34 = Zygote._pullback(%4, Base.getindex, %15, %29)
%35 = Base.getindex(%34, 1)
%36 = Base.getindex(%34, 2)
%37 = Zygote._pullback(%4, Main.:*, %26, %35)
%38 = Base.getindex(%37, 1)
%39 = Base.getindex(%37, 2)
%40 = Zygote._pullback(%4, Main.:<, %29, 2)
%41 = Base.getindex(%40, 1)
%42 = Base.getindex(%40, 2)
br 4 unless %41
br 3
3:
%43 = Zygote._pullback(%4, Base.vect, 1.0)
%44 = Base.getindex(%43, 1)
%45 = Base.getindex(%43, 2)
%46 = Zygote._pullback(%4, Main.:*, %38, %44)
%47 = Base.getindex(%46, 1)
%48 = Base.getindex(%46, 2)
br 5 (%47, 1)
4:
br 5 (%38, 2)
5: (%49, %60 :: UInt8)
%50 = Zygote._pullback(%4, Main.:+, %27, %49)
%51 = Base.getindex(%50, 1)
%52 = Base.getindex(%50, 2)
%53 = Zygote._pullback(%4, Base.iterate, %18, %32)
%54 = Base.getindex(%53, 1)
%55 = Base.getindex(%53, 2)
%56 = %54 === nothing
%57 = Base.not_int(%56)
br 6 (%51, 2) unless %57
br 2 (%54, %38, %51, 2)
6: (%58, %61 :: UInt8)
return %58
julia> adj.adjoint
1: (%1)
%2 = @61 !== 0x01
br 6 (nothing, nothing, nothing) unless %2
br 2 (%1, nothing, nothing, nothing, nothing)
2: (%3, %4, %5, %6, %7)
%8 = @60 !== 0x01
%9 = (@55)(%4)
%10 = Zygote.gradindex(%9, 2)
%11 = Zygote.gradindex(%9, 3)
%12 = (@52)(%3)
%13 = Zygote.gradindex(%12, 2)
%14 = Zygote.gradindex(%12, 3)
%15 = Zygote.accum(%6, %10)
%16 = Zygote.accum(%14, %5)
br 4 unless %8
br 3
3:
br 5 (%16)
4:
%17 = (@48)(%14)
%18 = Zygote.gradindex(%17, 2)
%19 = Zygote.gradindex(%17, 3)
%20 = (@45)(%19)
%21 = Zygote.accum(%5, %18)
br 5 (%21)
5: (%22)
%23 = @59 !== 0x01
%24 = (@42)(nothing)
%25 = Zygote.gradindex(%24, 2)
%26 = (@39)(%22)
%27 = Zygote.gradindex(%26, 2)
%28 = Zygote.gradindex(%26, 3)
%29 = (@36)(%28)
%30 = Zygote.gradindex(%29, 2)
%31 = Zygote.gradindex(%29, 3)
%32 = (@33)(%11)
%33 = Zygote.gradindex(%32, 2)
%34 = Zygote.accum(%25, %31)
%35 = (@30)(%34)
%36 = Zygote.gradindex(%35, 2)
%37 = Zygote.accum(%33, %36)
%38 = Zygote.accum(%7, %30)
br 6 (%37, %15, %38) unless %23
br 2 (%13, %37, %27, %15, %38)
6: (%39, %40, %41)
%42 = (@22)(%39)
%43 = Zygote.gradindex(%42, 2)
%44 = Zygote.accum(%40, %43)
%45 = (@19)(%44)
%46 = (@16)(%41)
%47 = Zygote.gradindex(%46, 2)
%48 = Zygote.gradindex(%46, 3)
%49 = (@13)(%48)
%50 = Zygote.gradindex(%49, 2)
%51 = (@10)(%47)
%52 = Zygote.gradindex(%51, 2)
%53 = (@7)(%52)
%54 = Zygote.gradindex(%53, 2)
%55 = Zygote.accum(%50, %54)
%56 = Zygote.tuple(nothing, %55)
return %56
With pullbacks filled in:
1: (Δ)
%2 = @61 !== 0x01
br 6 (nothing, nothing, nothing) unless %2
br 2 (Δ, nothing, nothing, nothing, nothing)
2: (%3, %4, %5, %6, %7)
%8 = @60 !== 0x01
%9 = ∂(Base.iterate)(%4)
%10 = Zygote.gradindex(%9, 2)
%11 = Zygote.gradindex(%9, 3)
%12 = ∂(Main.:+)(%3)
%13 = Zygote.gradindex(%12, 2)
%14 = Zygote.gradindex(%12, 3)
%15 = Zygote.accum(%6, %10)
%16 = Zygote.accum(%14, %5)
br 4 unless %8
br 3
3:
br 5 (%16)
4:
%17 = ∂(Main.:*)(%14)
%18 = Zygote.gradindex(%17, 2)
%19 = Zygote.gradindex(%17, 3)
%20 = ∂(Base.vect)(%19)
%21 = Zygote.accum(%5, %18)
br 5 (%21)
5: (%22)
%23 = @59 !== 0x01
%24 = ∂(Main.:<)(nothing)
%25 = Zygote.gradindex(%24, 2)
%26 = ∂(Main.:*)(%22)
%27 = Zygote.gradindex(%26, 2)
%28 = Zygote.gradindex(%26, 3)
%29 = ∂(Base.getindex)(%28)
%30 = Zygote.gradindex(%29, 2)
%31 = Zygote.gradindex(%29, 3)
%32 = ∂(Zygote.literal_getfield)(%11)
%33 = Zygote.gradindex(%32, 2)
%34 = Zygote.accum(%25, %31)
%35 = ∂(Zygote.literal_getfield)(%34)
%36 = Zygote.gradindex(%35, 2)
%37 = Zygote.accum(%33, %36)
%38 = Zygote.accum(%7, %30)
br 6 (%37, %15, %38) unless %23
br 2 (%13, %37, %27, %15, %38)
6: (%39, %40, %41)
%42 = ∂(Base.iterate)(%39)
%43 = Zygote.gradindex(%42, 2)
%44 = Zygote.accum(%40, %43)
%45 = ∂(Main.:(:))(%44)
%46 = ∂(Base.vect)(%41)
%47 = Zygote.gradindex(%46, 2)
%48 = Zygote.gradindex(%46, 3)
%49 = ∂(Base.vect)(%48)
%50 = Zygote.gradindex(%49, 2)
%51 = ∂(Main.:var"'")(%47)
%52 = Zygote.gradindex(%51, 2)
%53 = ∂(Base.vect)(%52)
%54 = Zygote.gradindex(%53, 2)
%55 = Zygote.accum(%50, %54)
%56 = Zygote.tuple(nothing, %55)
return %56 |
Another interesting bit: adding any operation around `o` in the `else` branch (for example `identity(o)`) makes the gradient work:
julia> function f(x)
y = [[x]', [x]]
r = 0.0
o = 1.0
for n in 1:2
o *= y[n]
if n < 2
proj_o = o * [1.0]
else
proj_o = identity(o) # @showgrad also works
end
r += proj_o
end
return r
end
f (generic function with 1 method)
julia> gradient(f, 1.2)
(3.4000000000000004,)
And the pullback IR:
1: (%1)
%2 = @64 !== 0x01
br 6 (nothing, nothing, nothing) unless %2
br 2 (%1, nothing, nothing, nothing, nothing)
2: (%3, %4, %5, %6, %7)
%8 = @63 !== 0x01
%9 = (@58)(%4)
%10 = Zygote.gradindex(%9, 2)
%11 = Zygote.gradindex(%9, 3)
%12 = (@55)(%3)
%13 = Zygote.gradindex(%12, 2)
%14 = Zygote.gradindex(%12, 3)
%15 = Zygote.accum(%6, %10)
br 4 unless %8
br 3
3:
%16 = (@51)(%14)
%17 = Zygote.gradindex(%16, 2)
%18 = Zygote.accum(%5, %17)
br 5 (%18)
4:
%19 = (@48)(%14)
%20 = Zygote.gradindex(%19, 2)
%21 = Zygote.gradindex(%19, 3)
%22 = (@45)(%21)
%23 = Zygote.accum(%5, %20)
br 5 (%23)
5: (%24)
%25 = @62 !== 0x01
%26 = (@42)(nothing)
%27 = Zygote.gradindex(%26, 2)
%28 = (@39)(%24)
%29 = Zygote.gradindex(%28, 2)
%30 = Zygote.gradindex(%28, 3)
%31 = (@36)(%30)
%32 = Zygote.gradindex(%31, 2)
%33 = Zygote.gradindex(%31, 3)
%34 = (@33)(%11)
%35 = Zygote.gradindex(%34, 2)
%36 = Zygote.accum(%27, %33)
%37 = (@30)(%36)
%38 = Zygote.gradindex(%37, 2)
%39 = Zygote.accum(%35, %38)
%40 = Zygote.accum(%7, %32)
br 6 (%39, %15, %40) unless %25
br 2 (%13, %39, %29, %15, %40)
6: (%41, %42, %43)
%44 = (@22)(%41)
%45 = Zygote.gradindex(%44, 2)
%46 = Zygote.accum(%42, %45)
%47 = (@19)(%46)
%48 = (@16)(%43)
%49 = Zygote.gradindex(%48, 2)
%50 = Zygote.gradindex(%48, 3)
%51 = (@13)(%50)
%52 = Zygote.gradindex(%51, 2)
%53 = (@10)(%49)
%54 = Zygote.gradindex(%53, 2)
%55 = (@7)(%54)
%56 = Zygote.gradindex(%55, 2)
%57 = Zygote.accum(%52, %56)
%58 = Zygote.tuple(nothing, %57)
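return %58
At first glance, the movement of an `accum` call looks like the relevant difference: in the failing IR it sits in block 2, ahead of the branch, whereas in this working IR it is inside block 3. [The original sentence was truncated after "movement of an"; the remainder is reconstructed from the two IR dumps.] |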
Thanks for looking into it! Interesting to see that adding any operation to that line circumvents the bug. |
This looks similar to #937 (comment), which seems to have fixed itself somehow (on Zygote v0.6.40). |
We've come across a strange bug involving control flow:
Changing the line:
proj_o = o
to, for example:
proj_o = identity(o)
fixes the issue and outputs:
(3.4000000000000004,)
Version information:
Original issue is here: ITensor/ITensorMPS.jl#73
The text was updated successfully, but these errors were encountered: