Initialising weights outside of layer declarations #1879
Comments
Mutating variants sound like a good idea. For applying them to initialized models, we could make some wrapper interface like …
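For concreteness, a minimal sketch of what a mutating variant might look like (`glorot_uniform!` is hypothetical, not an existing Flux function):

```julia
using Flux

# Hypothetical in-place counterpart of Flux.glorot_uniform: refills an
# existing array instead of allocating a new one.
function glorot_uniform!(x::AbstractArray)
    x .= Flux.glorot_uniform(size(x)...)
    return x
end

# Example: re-initialise the weight of an already-constructed layer.
d = Dense(2, 3)
glorot_uniform!(d.weight)
```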
One way, with `fmap`:

```julia
julia> m = Chain(Dense(2,3));

julia> fmap(m; exclude = x -> hasproperty(x, :weight)) do x
         x.weight .= (1:3)
         x
       end
Chain(
  Dense(2, 3),                          # 9 parameters
)

julia> m.layers[1].weight
3×2 Matrix{Float32}:
 1.0  1.0
 2.0  2.0
 3.0  3.0

julia> m.layers[1].bias
3-element Vector{Float32}:
 0.0
 0.0
 0.0
```

Is going by field name good enough? It might be. Could be wrapped up something like … It may also not be worth the hassle making this mutate, since it will only run once. Maybe the …
Yeah, my concern was relying on a particular field. We could always make …

My fear is that creating a single function for re-init would be too niche for the reasons discussed already (e.g. different parameters in the same layer wanting different init functions). Mutating variants of init functions make sense to me, however. They'll at least allow users to do things manually until we can think of good higher-level APIs.

I'd like to avoid adding another special function you have to remember to overload for any new layer, so that other people can re-weight it. My sketch above is much too rough, but can there be some nice API a bit like that? Most layers call bias … If it targets …

I like this. Let's write it so that the keys to ignore are a nested …

My thought is no, Flux will depend on Optimisers, so it can still live here. Initialization is specific to neural network models and not optimization.

So long as we are only doing mutable models, the easy way to apply this only to some branch is probably something like …

True, that's better!

Small sidenote: I would make the initialization method the first arg to support the …
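Assuming the point here is Julia's `do`-block syntax, putting the init function first lets callers supply it as an anonymous block; `reinit!` below is a hypothetical helper, not an existing API:

```julia
using Flux, Functors

# Hypothetical helper with the init function as the first argument.
# Replaces every array leaf with whatever f returns for it.
reinit!(f, model) = fmap(f, model)

# Because f comes first, do-block syntax works:
m = Chain(Dense(2, 3))
m2 = reinit!(m) do w
    randn(Float32, size(w))   # sketch: fresh random values of the same size
end
```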
The semantic definition of …
Indeed [re argument order]. I guess the next question is what gets passed to that function. Should this work, or should it get the size?

Is what it returns (like here) always copied back into the old array, or only if you do it? I presume it should return a re-built model à la …
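To make the question above concrete, these are the two contracts being compared, as a sketch (both function names are hypothetical):

```julia
# (a) The init function receives the existing array and returns a replacement
#     (or mutates it in place):
init_from_array(w::AbstractArray) = randn(Float32, size(w))

# (b) The init function only receives the size, matching the existing `init`
#     keyword of layer constructors such as Dense(2, 3; init = ...):
init_from_size(dims::Integer...) = randn(Float32, dims...)
```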
Is there a way to get Functors to only "see" down to a certain level? If:

```julia
function _init_weights!(m)
    if m isa Conv
        m.weight .*= 2
        m.bias .+= 5
    end
    return m
end
```

Now all that is required is a recursive function (…
The `exclude` keyword does this:

```julia
is_layer_or_leaf(m) = Functors.isleaf(m)
is_layer_or_leaf(::Conv) = true

fmap(_init_weights!, m; exclude = is_layer_or_leaf)
```
That's great! I tried something that's a pretty typical use case and it worked quite well:

```julia
julia> is_layer_or_leaf(m) = Functors.isleaf(m)
is_layer_or_leaf (generic function with 1 method)

julia> is_layer_or_leaf(::Conv) = true
is_layer_or_leaf (generic function with 2 methods)

julia> is_layer_or_leaf(::Dense) = true
is_layer_or_leaf (generic function with 3 methods)

julia> l = Chain(Dense(3, 3), Conv((3, 3), 3 => 10))
Chain(
  Dense(3 => 3),                        # 12 parameters
  Conv((3, 3), 3 => 10),                # 280 parameters
)                   # Total: 4 arrays, 292 parameters, 1.617 KiB.

julia> function _init_weights!(m::Conv)
         m.weight .*= 2
         m.bias .+= 5
         return m
       end
_init_weights! (generic function with 1 method)

julia> function _init_weights!(m::Dense)
         m.weight .*= 3
         m.bias .+= 4
         return m
       end
_init_weights! (generic function with 2 methods)

julia> fmap(_init_weights!, l; exclude = is_layer_or_leaf)
Chain(
  Dense(3 => 3),                        # 12 parameters
  Conv((3, 3), 3 => 10),                # 280 parameters
)                   # Total: 4 arrays, 292 parameters, 1.617 KiB.

julia> l[1].bias
3-element Vector{Float32}:
 4.0
 4.0
 4.0

julia> l[2].bias
10-element Vector{Float32}:
 5.0
 5.0
 5.0
 5.0
 5.0
 5.0
 5.0
 5.0
 5.0
 5.0
```

If this approach has no problems, then it seems pretty straightforward to define a …
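One possible shape for such a helper, as a sketch (the name `initmodel!` and its `layers` keyword are hypothetical, not an existing Flux API):

```julia
using Functors

# Hypothetical helper: apply f to every sub-layer whose type is listed in
# `layers`, treating those layer types (and ordinary leaves) as the stopping
# points of the recursion.
function initmodel!(f, model; layers)
    stop(x) = any(T -> x isa T, layers) || Functors.isleaf(x)
    fmap(model; exclude = stop) do m
        any(T -> m isa T, layers) ? f(m) : m
    end
end

# Usage sketch: double Conv weights, shift Dense biases, leave the rest alone.
# initmodel!(model; layers = (Conv, Dense)) do m
#     m isa Conv ? (m.weight .*= 2; m) : (m.bias .+= 1; m)
# end
```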
This ambiguity is part of why we don't already have a built-in …
I was trying to give this another go, but I noticed the above example (from here) doesn't work with DenseNet. The error was quite cryptic:

```julia
julia> model = DenseNet();

julia> fmap(_init_weights!, model; exclude = is_layer_or_leaf)
ERROR: MethodError: no method matching copyto!(::Bool, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Tuple{}, typeof(+), Tuple{Bool, Int64}})
Closest candidates are:
  copyto!(::Zygote.Buffer, ::Any) at ~/.julia/packages/Zygote/DkIUK/src/tools/buffer.jl:54
  copyto!(::Any, ::Base.Broadcast.Broadcasted{<:StaticArrays.StaticArrayStyle}) at ~/.julia/packages/StaticArrays/G7IlJ/src/broadcast.jl:68
  copyto!(::AbstractArray, ::Base.Broadcast.Broadcasted{<:Base.Broadcast.AbstractArrayStyle{0}}) at broadcast.jl:929
  ...
Stacktrace:
  [1] broadcasted
    @ ./broadcast.jl:1319 [inlined]
  [2] broadcasted
    @ ./broadcast.jl:1317 [inlined]
  [3] _init_weights!(m::Conv{2, 2, typeof(identity), Array{Float32, 4}, Bool})
    @ Main ./REPL[9]:3
  [4] #fmap#17
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:50 [inlined]
  [5] (::Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)})(x::Conv{2, 2, typeof(identity), Array{Float32, 4}, Bool})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
  [6] iterate
    @ ./generator.jl:47 [inlined]
  [7] _collect(c::Vector{Any}, itr::Base.Generator{Vector{Any}, Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
    @ Base ./array.jl:804
  [8] collect_similar
    @ ./array.jl:713 [inlined]
  [9] map
    @ ./abstractarray.jl:2976 [inlined]
 [10] _default_walk
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:43 [inlined]
 [11] fmap(f::typeof(_init_weights!), x::Vector{Any}; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [12] (::Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)})(x::Vector{Any})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [13] map
    @ ./tuple.jl:273 [inlined]
 [14] map(::Function, ::NamedTuple{(:layers,), Tuple{Vector{Any}}})
    @ Base ./namedtuple.jl:218
 [15] _default_walk(f::Function, x::Chain{Vector{Any}})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:43
 [16] fmap(f::typeof(_init_weights!), x::Chain{Vector{Any}}; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [17] #18
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:50 [inlined]
 [18] map
    @ ./tuple.jl:274 [inlined]
 [19] _default_walk
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:43 [inlined]
 [20] fmap(f::typeof(_init_weights!), x::Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [21] (::Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)})(x::Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [22] map
    @ ./tuple.jl:273 [inlined]
 [23] map(::Function, ::NamedTuple{(:layers,), Tuple{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}}})
    @ Base ./namedtuple.jl:218
 [24] _default_walk(f::Function, x::Chain{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:43
 [25] fmap(f::typeof(_init_weights!), x::Chain{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}}; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [26] (::Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)})(x::Chain{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [27] map
    @ ./tuple.jl:273 [inlined]
 [28] map(::Function, ::NamedTuple{(:layers,), Tuple{Chain{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}}}})
    @ Base ./namedtuple.jl:218
 [29] _default_walk(f::Function, x::DenseNet)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:43
 [30] fmap(f::typeof(_init_weights!), x::DenseNet; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [31] top-level scope
    @ REPL[16]:1
 [32] top-level scope
    @ ~/.julia/packages/CUDA/GGwVa/src/initialization.jl:52
```

Am I missing something here? Why isn't this working the way it's supposed to?
Most likely you are trying to accumulate into a …
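A sketch of the kind of guard that would address this, assuming the culprit is layers built with `bias = false`, which store `false` (a `Bool`) rather than an array, consistent with the `copyto!(::Bool, …)` method error above:

```julia
using Flux

function _init_weights!(m::Conv)
    m.weight .*= 2
    # Layers constructed with bias = false hold `false` instead of a vector,
    # so only update the bias when it is actually an array.
    if m.bias isa AbstractArray
        m.bias .+= 5
    end
    return m
end
```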
Following the discussions in FluxML/Metalhead.jl#119, I realised that currently there is no way for the user to programmatically pass in weight initialisation strategies for layers in a `Chain`-like structure based on the type of the layer (after the layer has been declared already, that is). This would be quite the useful feature to have given that many recent models use specific weight initialisations for some types of layers.

An initial idea that I had was to add a mutating version of the existing initialisation functions. Then we could have a wrapper function that mutated the weights of the already existing layer instead of having to copy over an entirely new layer just to change the initial weights. I'm unsure if this clashes with something (and I also don't really have ideas on if there are efficient ways to do this already via existing functionalities), so opening this up for discussion in case there's some conflict before I sit down to write it up.

\cc @darsnack