BoundsError calling Flux.reset! #1297

Open
mentics opened this issue Aug 29, 2022 · 2 comments · May be fixed by JuliaDiff/ChainRules.jl#569
Labels: bug (Something isn't working), ChainRules (adjoint -> rrule, and further integration)

Comments

mentics commented Aug 29, 2022

Calling Flux.reset! inside a Flux loss function throws ERROR: BoundsError: attempt to access Tuple{} at index [0] from Zygote code.

Julia 1.8.0
Zygote 0.6.45
Flux v0.13.5

Test case:

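using Flux

function test()
    model = Dense(1 => 1)
    params = Flux.params(model)
    function loss(x, y)
        # Dense has no recurrent state, but Zygote still differentiates through reset!
        Flux.reset!(model)
        Flux.Losses.mse(model(x), y)
    end
    # one training step on a single (x, y) pair
    Flux.train!(loss, params, [([1.0],[1.0])], Descent())
end

test()  # throws the BoundsError shown below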

Stacktrace:

ERROR: BoundsError: attempt to access Tuple{} at index [0]
Stacktrace:
  [1] getindex(t::Tuple, i::Int64)
    @ Base .\tuple.jl:29
  [2] last(a::Tuple{})
    @ Base .\abstractarray.jl:479
  [3] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::typeof(foldl), op::Base.var"#57#58"{typeof(Flux.reset!)}, x::Tuple{}; init::Nothing)
    @ ChainRules C:\Users\joel\.julia\packages\ChainRules\fgVxV\src\rulesets\Base\mapreduce.jl:448
  [4] chain_rrule_kw(::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::Function, ::NamedTuple{(:init,), Tuple{Nothing}}, ::Function, ::Function, ::Vararg{Any})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\chainrules.jl:230
  [5] macro expansion
    @ C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0 [inlined]
  [6] _pullback(::Zygote.Context{true}, ::Base.var"#foldl##kw", ::NamedTuple{(:init,), Tuple{Nothing}}, ::typeof(foldl), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:9
  [7] _pullback
    @ .\tuple.jl:555 [inlined]
  [8] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::Tuple{})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
  [9] _pullback
    @ C:\Users\joel\.julia\packages\Flux\EXOFx\src\layers\recurrent.jl:180 [inlined]
 [10] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.reset!), args::Matrix{Float32})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
 [11] _pullback
    @ .\abstractarray.jl:2774 [inlined]
 [12] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, typeof(identity)}})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
 [13] _pullback
    @ C:\Users\joel\.julia\packages\Flux\EXOFx\src\layers\recurrent.jl:180 [inlined]
 [14] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.reset!), args::Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
 [15] _pullback
    @ C:\data\julia\journey\modules\lev2-util\ml\TryFlux.jl:23 [inlined]
 [16] _pullback(::Zygote.Context{true}, ::TryFlux.var"#loss#21"{Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, ::Vector{Float64}, ::Vector{Float64})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
 [17] _apply(::Function, ::Vararg{Any})
    @ Core .\boot.jl:816
 [18] adjoint
    @ C:\Users\joel\.julia\packages\Zygote\qGFGD\src\lib\lib.jl:203 [inlined]
 [19] _pullback
    @ C:\Users\joel\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 [inlined]
 [20] _pullback
    @ C:\Users\joel\.julia\packages\Flux\EXOFx\src\optimise\train.jl:120 [inlined]
 [21] _pullback(::Zygote.Context{true}, ::Flux.Optimise.var"#37#40"{TryFlux.var"#loss#21"{Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, Tuple{Vector{Float64}, Vector{Float64}}})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
 [22] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface.jl:373
 [23] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface.jl:96
 [24] macro expansion
    @ C:\Users\joel\.julia\packages\Flux\EXOFx\src\optimise\train.jl:119 [inlined]
 [25] macro expansion
    @ C:\Users\joel\.julia\packages\ProgressLogging\6KXlp\src\ProgressLogging.jl:328 [inlined]
 [26] train!(loss::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, data::Vector{Tuple{Vector{Float64}, Vector{Float64}}}, opt::Flux.Optimise.Descent; cb::Flux.Optimise.var"#38#41")
    @ Flux.Optimise C:\Users\joel\.julia\packages\Flux\EXOFx\src\optimise\train.jl:117
 [27] train!
    @ C:\Users\joel\.julia\packages\Flux\EXOFx\src\optimise\train.jl:113 [inlined]
 [28] test()
    @ TryFlux C:\data\julia\journey\modules\lev2-util\ml\TryFlux.jl:26
 [29] top-level scope
    @ REPL[12]:1

Removing the call to Flux.reset! removes the error.

@mcabbott
Member

The error comes from inside the ChainRules rule for foldl, applied to an empty tuple. That empty tuple comes from this foreach on a leaf node:

reset!(m::Recur) = (m.state = m.cell.state0)   # recurrent layer: restore the initial state
reset!(m) = foreach(reset!, functor(m)[1])     # anything else: recurse into its children
julia> Functors.functor(Dense(1 => 1).weight)[1]
()
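To see why a plain Dense layer ends up calling foreach on an empty tuple, here is a self-contained plain-Julia sketch of the same recursion. ToyRecur, ToyDense, toy_children and toy_reset! are hypothetical stand-ins for Recur, Dense, Functors.functor and Flux.reset!; they are not Flux's API, and the counter only records how often the recursion reaches a leaf.

```julia
# Hypothetical stand-ins for Flux layers; not Flux's actual types.
mutable struct ToyRecur
    state0::Vector{Float64}
    state::Vector{Float64}
end

struct ToyDense
    weight::Matrix{Float64}
    bias::Vector{Float64}
end

# Children of a node, mimicking functor(m)[1]: fields for a layer,
# an empty tuple for a leaf array (like functor(Dense(1 => 1).weight)[1] == ()).
toy_children(m::ToyDense) = (weight = m.weight, bias = m.bias)
toy_children(::AbstractArray) = ()

# Counts how many times the recursion runs foreach over an empty tuple.
const empty_foreach_calls = Ref(0)

# Recurrent layer: restore the initial state (same shape as Flux's method).
toy_reset!(m::ToyRecur) = (m.state = m.state0)

# Anything else: recurse into children; leaves produce foreach(f, ()).
function toy_reset!(m)
    kids = toy_children(m)
    kids === () && (empty_foreach_calls[] += 1)
    foreach(toy_reset!, kids)
end
```

Calling toy_reset!(ToyDense(zeros(1, 1), zeros(1))) reaches two leaves (weight and bias), so foreach runs twice on (). In ordinary execution that is harmless; it only fails once Zygote routes the call through the foldl rule.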

It still fails with JuliaDiff/ChainRules.jl#569 applied, with this stacktrace:

julia> test()
ERROR: BoundsError: attempt to access Tuple{} at index [0]
Stacktrace:
  [1] getindex(t::Tuple, i::Int64)
    @ Base ./tuple.jl:29
  [2] last(a::Tuple{})
    @ Base ./abstractarray.jl:500
  [3] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::Base.var"#57#58"{typeof(Flux.reset!)}, init::Base._InitialValue, x::Tuple{Nothing})
    @ ChainRules ~/.julia/packages/ChainRules/fK4AU/src/rulesets/Base/mapreduce.jl:465
  [4] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::Base.var"#57#58"{typeof(Flux.reset!)}, init::Nothing, x::Tuple{})
    @ ChainRules ~/.julia/packages/ChainRules/fK4AU/src/rulesets/Base/mapreduce.jl:488
  [5] chain_rrule
    @ ~/.julia/packages/Zygote/qGFGD/src/compiler/chainrules.jl:218 [inlined]
  [6] macro expansion
    @ ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0 [inlined]
  [7] _pullback(::Zygote.Context{true}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Nothing, ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:9
  [8] _pullback
    @ ./reduce.jl:170 [inlined]
  [9] _pullback(::Zygote.Context{true}, ::Base.var"##mapfoldl#286", ::Nothing, ::typeof(mapfoldl), ::typeof(identity), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
 [10] _pullback
    @ ./reduce.jl:170 [inlined]
 [11] _pullback(::Zygote.Context{true}, ::Base.var"#mapfoldl##kw", ::NamedTuple{(:init,), Tuple{Nothing}}, ::typeof(mapfoldl), ::typeof(identity), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
 [12] _pullback
    @ ./reduce.jl:193 [inlined]
 [13] _pullback(::Zygote.Context{true}, ::Base.var"##foldl#287", ::Base.Pairs{Symbol, Nothing, Tuple{Symbol}, NamedTuple{(:init,), Tuple{Nothing}}}, ::typeof(foldl), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
 [14] _pullback
    @ ./reduce.jl:193 [inlined]
 [15] _pullback(::Zygote.Context{true}, ::Base.var"#foldl##kw", ::NamedTuple{(:init,), Tuple{Nothing}}, ::typeof(foldl), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
 [16] _pullback
    @ ./tuple.jl:602 [inlined]
 [17] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
 [18] _pullback
    @ ~/.julia/packages/Flux/EXOFx/src/layers/recurrent.jl:180 [inlined]
 [19] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.reset!), args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
 [20] _pullback
    @ ./abstractarray.jl:3036 [inlined]
 [21] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, typeof(identity)}})
    @ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
 [22] _pullback
    @ ~/.julia/packages/Flux/EXOFx/src/layers/recurrent.jl:180 [inlined]
 [23] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.reset!), args::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}})
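Both traces end with the rule reading its result via last on an empty tuple (frames [2] and [3]). A minimal plain-Julia sketch of that failure mode and of the obvious guard, assuming a fold rule that records its intermediate accumulators for the pullback. fold_steps, fold_result_unguarded and fold_result_guarded are hypothetical names, not ChainRules internals.

```julia
# Forward sweep of a hypothetical fold rule: record every intermediate
# accumulator (as a tuple) so a pullback could later walk them in reverse.
function fold_steps(op, init, xs::Tuple)
    steps = ()
    acc = init
    for x in xs
        acc = op(acc, x)
        steps = (steps..., acc)
    end
    return steps
end

# Unguarded: read the result off the last recorded step. With xs == ()
# there are no steps, so last(()) throws BoundsError, as in the traces.
fold_result_unguarded(op, init, xs::Tuple) = last(fold_steps(op, init, xs))

# Guarded: an empty input means op never ran, so the fold's value is init.
function fold_result_guarded(op, init, xs::Tuple)
    steps = fold_steps(op, init, xs)
    return isempty(steps) ? init : last(steps)
end
```

With the op that foreach passes to foldl, (_, x) -> (f(x); nothing), and init = nothing, the guarded version simply returns nothing for an empty tuple.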

One possible fix is to mark foreach over an empty tuple as non-differentiable:

julia> ChainRulesCore.@non_differentiable foreach(f, ::Tuple{})

julia> Zygote.refresh()

julia> test()
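This workaround leans on dispatch: a method for foreach(f, ::Tuple{}) is more specific than one for a general Tuple, so only the empty case is intercepted and non-empty tuples still go through the fold rule. A plain-Julia sketch of that method selection; toy_rule is hypothetical and only stands in for rule lookup, it does not reproduce what @non_differentiable generates.

```julia
# Hypothetical stand-in for rule lookup; not ChainRulesCore's API.
# Generic case: any tuple goes through the fold-based rule.
toy_rule(::typeof(foreach), f, xs::Tuple) = :fold_rule

# Tuple{} is a strict subtype of Tuple, so this method wins for () only,
# analogous in spirit to `@non_differentiable foreach(f, ::Tuple{})`.
toy_rule(::typeof(foreach), f, ::Tuple{}) = :non_differentiable
```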

More generally, should Flux be differentiating inside reset! at all?

mcabbott added the bug (Something isn't working) and ChainRules (adjoint -> rrule, and further integration) labels on Aug 29, 2022
@ToucheSir
Member

More generally, should Flux be differentiating inside reset! at all?

My understanding of FluxML/Flux.jl#808 (comment) is that this is intentional, so that the initial state can be trainable, but perhaps there's another way for us to make that work.
