Freezing layers at model construction time #1931

Open
jondeuce opened this issue Apr 6, 2022 · 9 comments

@jondeuce (Contributor) commented Apr 6, 2022

There have been several issues/PRs related to freezing model parameters:

  1. freeze parameters #1022
  2. How to keep weights of parts of a model fixed under Flux.train! #1001
  3. Implement APIs of freeze parameters and freeze layers #1101
  4. delete! for Params Zygote.jl#505
  5. Per-leaf freezing Optimisers.jl#49

Right now, the recommendation made in the documentation is to manually specify which parameters should not be trained using some combination of Flux.params and Zygote.delete!.
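For concreteness, the documented pattern looks roughly like the following (a minimal sketch; the two-layer model and the choice to freeze the first layer are purely illustrative):

```julia
using Flux

m = Chain(Dense(2, 3, tanh), Dense(3, 2))

# Collect all parameters, then drop the ones that should stay fixed:
ps = Flux.params(m)
delete!(ps, m[1].weight)
delete!(ps, m[1].bias)

# Equivalently, only collect the parameters you do want trained:
ps = Flux.params(m[2])
```

Either way, the code that builds `ps` has to know which positions in the model hold the frozen layers.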

While this works, it is somewhat inflexible in several respects:

  1. Training routines must be aware of the model architecture in order to select which parameters to freeze
  2. Specifying that a layer is frozen is often much more convenient to do at model construction time, particularly if the frozen layer is nested deeply inside the model
  3. It is not clear how the current approach would fit into the functional-style approach which is coming in v0.13, since Params would no longer be used at all (would one need to e.g. fmap over a model and somehow mark specific layers as frozen before passing to gradient?)

For these reasons, I often find myself defining a Frozen layer (similar to #1001) which looks something like this:

using Flux
using Flux: @adjoint

struct Frozen{F}
    f::F
end
Flux.@functor Frozen # need functor for e.g. `fmap`
Flux.trainable(::Frozen) = NamedTuple() # no trainable parameters

# Something like `whitebox_apply` is required to explicitly treat `f` as a "white box":
# propagate gradients through `f`, but treat `f` itself as a constant functor
(l::Frozen)(xs...) = whitebox_apply(l.f, xs...)

whitebox_apply(f, xs...) = f(xs...)

@adjoint function whitebox_apply(f, xs...)
    y, J = Flux.pullback(f, xs...)
    y, Δ -> (nothing, J(Δ)...)
end

A frozen layer l::Frozen wraps a functor f and has two properties:

  1. l(x) = f(x) is differentiable with respect to x (as opposed to e.g. l(x) = dropgrad(f(x)) which would treat f(x) as constant)
  2. f is treated as a constant functor: gradients of l(x) with respect to parameters internal to f return zero

Below is some test code to illustrate how this layer should behave:

Examples/tests
x = rand(Float32, 2)
l1 = Dense(2, 3, tanh)
l2 = Dense(3, 4, tanh)
l3 = Dense(4, 2, identity)

m0 = Chain(l1, l2, l3)
m1 = Chain(l1, Frozen(l2), l3) # identical to `m0` but with the middle layer frozen

p0 = Flux.params(m0)
p1 = Flux.params(m1)
pfree = Flux.params(l1, l3)
pfrozen = Flux.params(l2)

# Basics
@assert all(p ∈ p1 for p in pfree) # free params are present
@assert all(p ∉ p1 for p in pfrozen) # frozen params are not

∇p1 = gradient(() -> sum(m1(x)), pfrozen)
@assert all(∇p1[p] === nothing for p in pfrozen) # frozen params have zero gradients, even if passed to `gradient` explicitly

∇p1 = gradient(() -> sum(m1(x)), p1)
@assert all(haskey(∇p1, p) for p in pfree) # free params have gradients
@assert !any(haskey(∇p1, p) for p in pfrozen) # frozen params do not have gradients

∇p0 = gradient(() -> sum(m0(x)), p0)
@assert all(∇p0[p] ≈ ∇p1[p] for p in pfree) # gradients are equal for free params

# This loss is constant as a function of `pfree`: `m0` and `m1` co-vary exactly as `pfree` changes,
# and therefore the difference `m0(x) - m1(x)` is zero with zero gradient w.r.t. `pfree`.
# However, since `m1` is treated as a constant function of `pfrozen` but `m0` is not,
# the gradient of `m0(x) - m1(x)` is nonzero w.r.t. `pfrozen`.
loss = () -> sum(m0(x) - m1(x))

∇p0 = gradient(loss, p0)
@assert all(iszero(∇p0[p]) for p in pfree) # gradient == 0 for free parameters
@assert !any(iszero(∇p0[p]) for p in pfrozen) # gradient != 0 for frozen parameters

∇p1 = gradient(loss, p1)
@assert all(iszero(∇p1[p]) for p in pfree) # gradient == 0 for free parameters
@assert !any(haskey(∇p1, p) for p in pfrozen) # gradients not present for frozen parameters
@assert all(∇p0[p] ≈ ∇p1[p] for p in pfree) # gradients are equal for free params

If there is interest in including a layer like Frozen into Flux I would be happy to make a PR. Of course, if there is an easy way to do what I'm describing which I have overlooked, please do let me know and I'll close this issue.

@darsnack (Member) commented Apr 6, 2022

The functional approach that you link to is discussed a bit in FluxML/Optimisers.jl#49. The final solution probably will be something like your description—walking the model and marking certain leaf nodes as frozen, either in an auxiliary structure or returning a wrapped one.
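The "returning a wrapped one" variant of that walk might look like the following sketch, reusing the Frozen wrapper from the issue; `freeze_layers` and its predicate argument are hypothetical names, not an existing API:

```julia
using Flux  # `fmap` comes from Functors.jl and is available as Flux.fmap

struct Frozen{T}; layer::T; end
Flux.@functor Frozen
Flux.trainable(::Frozen) = NamedTuple()
(f::Frozen)(x) = f.layer(x)

# Hypothetical helper: walk an already-built model and wrap every node
# matching `pred` in `Frozen`. The `exclude` keyword tells `fmap` to stop
# recursing (and apply the function) exactly at those nodes.
freeze_layers(pred, model) = Flux.fmap(Frozen, model; exclude = pred)

m  = Chain(Dense(2, 3, tanh), Dense(3, 2))
m′ = freeze_layers(l -> l isa Dense && size(l.weight, 1) == 3, m)  # wraps the first Dense
```

The auxiliary-structure variant would instead record the matched paths and hand them to the optimiser, leaving the model untouched.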

I think we have to support some way to do this after construction, since there are many cases where you don't have access to the model's construction. Depending on how freezing lands with the functional approach, I see two ways of offering freezing at construction:

  1. Use of a @freeze macro that allows you to define a model and mark parts as frozen, but under the hood it does the freezing post-construction
  2. Do something like your suggestion by allowing the wrapper described above to be used at construction as well

Even if we don't include this in Flux, I think it's worth documenting so that other users can benefit.

@jondeuce (Contributor) commented Apr 6, 2022

Thanks for the link, that's an interesting approach. And good point re: often not having access to model construction, I certainly agree that having a way to freeze layers post-construction is useful.

I do feel that semantically, a layer being frozen should probably be thought of as an attribute of the model layer, as opposed to an attribute of e.g. the optimizer state. Though of course this is debatable, since "being frozen" really only makes sense in a training context (the whole model is "frozen" during inference, after all).

  1. Use of a @freeze macro that allows you to define a model and mark parts as frozen, but under the hood it does the freezing post-construction

This is intriguing, but would you not still have to store some information about which layers are frozen into the model itself? E.g. by having @freeze Dense(2,4) return Frozen(Dense(2,4)) or similar.

@darsnack (Member) commented Apr 6, 2022

This is intriguing, but would you not still have to store some information about which layers are frozen into the model itself? E.g. by having @freeze Dense(2,4) return Frozen(Dense(2,4)) or similar.

Yes, if the way we do freezing is through some wrapper on the model itself. And in that case, the macro would end up just being more complicated than writing Frozen(Dense(2, 4)).

If the way we do freezing is through auxiliary information passed to the optimizer, then this macro would be a convenient way to offer freezing syntax on construction. Of course, we want to minimize the amount of "stuff" that you need to pass around, so the final ergonomics will be one thing we'll consider as we evaluate our options.
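In the wrapper-based case, such a macro really would reduce to a thin alias for the constructor; a hypothetical sketch (reusing the Frozen definition from the issue):

```julia
using Flux

struct Frozen{T}; layer::T; end
Flux.@functor Frozen
Flux.trainable(::Frozen) = NamedTuple()
(f::Frozen)(x) = f.layer(x)

# Hypothetical: in the wrapper-based design, marking a layer at
# construction time just wraps the expression in `Frozen`.
macro freeze(ex)
    :(Frozen($(esc(ex))))
end

m = Chain(Dense(2, 3, tanh), @freeze(Dense(3, 3, tanh)), Dense(3, 2))
```

The auxiliary-information design would make the macro genuinely useful, since it would have to record the frozen positions somewhere other than the model.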

@mcabbott (Member) commented Apr 7, 2022

To permanently freeze something, I think it should be enough just to exclude all fields from trainable. (Edit -- as you said!)

So the Frozen layer is as simple as this:

struct Frozen{T}; layer::T; end
Flux.@functor Frozen
Flux.trainable(f::Frozen) = NamedTuple()
(f::Frozen)(x) = f.layer(x)

julia> m = Chain(Frozen(Dense([2.0;;])), Dense([3.0;;]));

julia> s = Optimisers.setup(Optimisers.Descent(), m)
(layers = ((layer = nothing,), (weight = Leaf(Descent{Float32}(0.1), nothing), bias = Leaf(Descent{Float32}(0.1), nothing), σ = nothing)),)

julia> g = gradient(m -> sum(m([1])), m)[1]
(layers = ((layer = (weight = [3.0;;], bias = [3.0], σ = nothing),), (weight = [2.0;;], bias = Fill(1.0, 1), σ = nothing)),)

julia> s2, m2 = Optimisers.update(s, m, g);

julia> m2[1].layer.weight  # unchanged
1×1 Matrix{Float64}:
 2.0

julia> m2[2].weight
1×1 Matrix{Float64}:
 2.7999999970197678

@jondeuce (Contributor) commented Apr 7, 2022

To permanently freeze something, I think it should be enough just to exclude all fields from trainable

This was what I tried initially. It actually fails the tests in my original post (see the "Examples/tests" dropdown). The reason is that non-zero gradients would still be computed for frozen layers, when ideally gradient should return nothing for these layers. For example, this gradient really should be just nothing, as it would be for any other constant function:

julia> m = Frozen(Dense([2.0;;]));

julia> g = gradient(m -> sum(m([1])), m)[1]
(layer = (weight = [2.0;;], bias = Fill(1.0, 1), σ = nothing),)

If you use something like whitebox_apply, the gradient is just nothing.

Maybe it wouldn't be an issue in practice if Optimisers.update handles it correctly by skipping the frozen layers via trainable, but still the wrong gradient is being returned.

@mcabbott (Member) commented Apr 7, 2022

Oh I'm sorry, I clearly didn't read the whole thing.

What's gained by demanding a simpler gradient? The same rrules will often compute both the gradient you need and the one you don't, even if you tell Zygote to throw some away. Once Zygote learns to handle thunks, it may delay these calculations. Is it obvious that whitebox_apply is going to be faster?

Placing several steps inside the frozen branch, here's what I see:

julia> m = Chain(Frozen(Chain(Dense(100 => 100, relu), Dense(100 => 100, relu))), Dense(100 => 100)); x = rand(Float32, 100, 100);

julia> (l::Frozen)(x) = l.layer(x);  Zygote.refresh();

julia> @btime gradient(m -> sum(abs, m($x)), $m)[1];
  121.500 μs (144 allocations: 912.08 KiB)

julia> (l::Frozen)(x) = whitebox_apply(l.layer, x);   Zygote.refresh();

julia> @btime gradient(m -> sum(abs, m($x)), $m)[1];
  120.083 μs (133 allocations: 911.34 KiB)

julia> @btime $x * $x;
  6.517 μs (2 allocations: 39.11 KiB)

@ToucheSir (Member) commented
Assuming we get thunk support soonish (though having it at the top level might be more challenging), I suppose it comes down to a more philosophical discussion around how one envisions the "training loop" to work. e.g. JAX frameworks using https://github.com/deepmind/optax follow a similar model as the linked Optimisers PR: that is, freezing parameters means not propagating their gradients in the optimizer step.

In contrast, the proposal here imagines gradient flow not permeating the AD boundary at all. Both are valid and I don't think it's a matter of either-or. However, we've purposefully held off from introducing these kinds of non-control-flow "utility" layers in Flux for a multitude of reasons, and I'm hesitant to start now unless the need is overwhelming (e.g. there's no way to do this in "userspace" and it constitutes a plurality of recent feedback comments about Flux) and we're unable to provide a compelling alternative in a reasonable amount of time.

@jondeuce (Contributor) commented Apr 7, 2022

@mcabbott I actually hadn't thought about whether this would be faster. I imagine you are correct that it wouldn't be once optimizations like thunks land. In fact, the various partials are probably required to differentiate through the layer anyway.

My point is more about correctness. I think that for consistency, a model m::Frozen should behave as similarly as possible to a generic non-parametric m::Function. For example,

julia> m = abs2; # arbitrary function

julia> gradient(m -> sum(m, [1.0]), m)[1] === nothing
true

julia> m = Frozen(Dense(2, 4)); # frozen model

julia> gradient(m -> sum(m([1.0, 1.0])), m)[1] === nothing
false # true if using e.g. `whitebox_apply`

@ToucheSir This is a completely valid reason, of course. My argument for including a layer like this would essentially be for user convenience, as well as the fact that it's somewhat tricky to implement on your own and handle edge cases. For example, how should tied weights figure into this conversation? E.g. what should the behaviour be if you have both a layer l and Frozen(l) in the same model (perhaps in a Siamese network type architecture)? I'm not sure... but it may be worth figuring out and then providing this layer to users.
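To make the tied-weights question concrete, here is a sketch (again with the Frozen wrapper from the issue) in which the same Dense layer appears both free and frozen:

```julia
using Flux

struct Frozen{T}; layer::T; end
Flux.@functor Frozen
Flux.trainable(::Frozen) = NamedTuple()
(f::Frozen)(x) = f.layer(x)

l = Dense(2, 2)
m = Chain(l, Frozen(l))  # the same layer appears free and frozen

# `Flux.params` deduplicates by object identity, so the shared arrays are
# collected once, through the free occurrence of `l` -- it is unclear
# whether they should count as trainable here.
ps = Flux.params(m)
```

Any optimiser update to the "free" copy of the weights silently moves the "frozen" copy too, which is the ambiguity in question.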

@ToucheSir (Member) commented
Tied weights are a problem for all container layers, so in that sense I don't think there's much we could do for a specific layer like Frozen.

As for the value of the layer itself as a standard user convenience, I think it'd be good to get the opinion of one or more of @darsnack, @CarloLucibello or @lorenzoh since they've done a lot more work in the trenches with non-trivial Flux models than I have.
