
Flux.destructure gives MethodError when used with non-trainable parameters #1553

Closed
@ctessum

Hello,

I've experienced an error (using Flux 0.11.6) with the following code:

```julia
using Flux

struct train_part
    a
    b
end

# Forward pass: multiply the input by both matrices.
function (t::train_part)(x)
    t.a * t.b * x
end

# Expose only field `a` to Functors, intending `a` to be the only trainable matrix.
Flux.@functor train_part (a,)

m = Chain(
    Dense(2, 2, tanh),
    train_part(zeros(2, 2), zeros(2, 2)),
)

Flux.destructure(m)
```

This is the error message:

```
ERROR: MethodError: no method matching train_part(::Matrix{Float64})
Closest candidates are:
  train_part(::Any, ::Any) at REPL[12]:2
Stacktrace:
  [1] (::var"#3#4")(y::NamedTuple{(:a,), Tuple{Matrix{Float64}}})
    @ Main ~/.julia/packages/Functors/YlETM/src/functor.jl:12
  [2] fmap1(f::Function, x::train_part)
    @ Functors ~/.julia/packages/Functors/YlETM/src/functor.jl:30
  [3] #fmap#13
    @ ~/.julia/packages/Functors/YlETM/src/functor.jl:35 [inlined]
  [4] (::Functors.var"#14#15"{IdDict{Any, Any}, Flux.var"#33#35"{Zygote.Buffer{Any, Vector{Any}}}})(x::train_part)
    @ Functors ~/.julia/packages/Functors/YlETM/src/functor.jl:35
  [5] map
    @ ./tuple.jl:214 [inlined]
  [6] fmap1(f::Function, x::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, train_part}})
    @ Functors ~/.julia/packages/Functors/YlETM/src/functor.jl:30
  [7] fmap(f::Flux.var"#33#35"{Zygote.Buffer{Any, Vector{Any}}}, x::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, train_part}}; cache::IdDict{Any, Any})
    @ Functors ~/.julia/packages/Functors/YlETM/src/functor.jl:35
  [8] fmap
    @ ~/.julia/packages/Functors/YlETM/src/functor.jl:34 [inlined]
  [9] destructure(m::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, train_part}})
    @ Flux ~/.julia/packages/Flux/goUGu/src/utils.jl:409
 [10] top-level scope
    @ REPL[16]:1
```

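I suspect that `destructure` extracts only the field of `train_part` that is listed in `@functor` (`a`), and then tries to rebuild the struct from that single value by calling `train_part(a)`. Frame [1] of the stack trace shows the rebuild closure receiving a `NamedTuple` with only `:a`, and since the only constructor takes two arguments, the call fails with the `MethodError` above.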
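To illustrate this hypothesis, here is a minimal plain-Julia sketch with no Flux or Functors dependency. It mimics a functor-style `(children, rebuild)` pair for a struct where only field `a` is exposed. The names `TrainPart`, `naive_functor`, and `keeping_functor` are hypothetical, and this is not Functors.jl's actual implementation; it only shows why splatting the exposed children into a two-field constructor fails, and one assumed way a rebuild could avoid that.

```julia
# Hypothetical stand-in for `train_part` from the report (not Flux/Functors code).
struct TrainPart
    a
    b
end

# Naive rebuild: splat only the exposed children into the constructor.
# With one exposed field this calls TrainPart(a), but the only
# constructor takes two arguments, so it throws a MethodError.
naive_functor(x::TrainPart) = ((a = x.a,), y -> TrainPart(y...))

# Alternative rebuild (assumed for illustration): take the exposed field
# from `y` and copy the non-exposed field `b` unchanged from the original `x`.
keeping_functor(x::TrainPart) = ((a = x.a,), y -> TrainPart(y.a, x.b))

t = TrainPart(zeros(2, 2), ones(2, 2))

# Try the naive rebuild after updating the exposed child.
children, rebuild = naive_functor(t)
naive_error = try
    rebuild(map(c -> c .+ 1, children))
    nothing
catch err
    err
end
println(typeof(naive_error))   # MethodError

# The alternative rebuild succeeds and leaves `b` untouched.
children, rebuild = keeping_functor(t)
t2 = rebuild(map(c -> c .+ 1, children))
println(t2.a == ones(2, 2) && t2.b == ones(2, 2))   # true
```

The design choice in `keeping_functor` is that fields not exposed as children are simply carried over from the original object, so the constructor always receives every field it needs.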
