Implement APIs of freeze parameters and freeze layers #1101

Closed · wants to merge 7 commits
26 changes: 23 additions & 3 deletions docs/src/models/advanced.md
@@ -62,12 +62,32 @@ During training, the gradients will only be computed for (and applied to) the la
Flux.params(m[1], m[3:end])
```

Sometimes finer-grained control is needed.
We can freeze a specific parameter of a specific layer that has already entered a `Params` object `ps` by simply deleting it from `ps`:

```julia
ps = params(m)
delete!(ps, m[2].b)
```
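
The reduced `ps` is then passed to `train!` (or `update!`) exactly like the full parameter set, so `m[2].b` stays fixed while every other parameter is trained. A minimal, self-contained sketch; the toy model, loss, data, and optimiser below are purely illustrative:

```julia
using Flux: Chain, Dense, softmax, params, train!
using Flux: crossentropy, ADAM

m = Chain(Dense(4, 3), Dense(3, 2), softmax)    # toy model

ps = params(m)
delete!(ps, m[2].b)                             # freeze the bias of the second layer

data = [(rand(4, 10), rand(2, 10))]             # dummy inputs and targets
train!((x, y) -> crossentropy(m(x), y), ps, data, ADAM(0.005))
```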

The `freezelayers!` function removes the parameters of whole layers from a `Params` object, preventing them from being updated. It returns the modified `Params`, so the call can be passed directly to `train!`. The following example stops the parameters of the first and third layers from being updated:

```julia
m = Chain(
Dense(4, 2),
Dense(2, 3),
Dense(3, 4),
Dense(4, 3),
softmax
)

freezed_layer_indexes = [1, 3]

train!(
(x, y) -> crossentropy(m(x), y),
freezelayers!(params(m), m, freezed_layer_indexes),
[(rand(4, 10), rand(1, 10))],
ADAM(0.005)
)
```
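
Since `freezelayers!` works by deleting the affected parameters from the `Params` collection, its effect can be checked directly. A small sketch, assuming the `m` defined above; each frozen `Dense` layer contributes a weight matrix and a bias vector:

```julia
ps = params(m)
n_before = length(collect(ps))    # 8 parameter arrays: W and b for each Dense layer

freezelayers!(ps, m, [1, 3])

# W and b of layers 1 and 3 were removed; the rest are still tracked
length(collect(ps)) == n_before - 4
```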
2 changes: 1 addition & 1 deletion docs/src/training/training.md
@@ -75,7 +75,7 @@ Training data can be conveniently partitioned for mini-batch training using the
```julia
X = rand(28, 28, 60000)
Y = rand(0:9, 60000)
data = DataLoader(X, Y, batchsize=128)
```

Note that, by default, `train!` only loops over the data once (a single "epoch").
4 changes: 2 additions & 2 deletions src/optimise/Optimise.jl
@@ -1,8 +1,8 @@
module Optimise

export train!, update!, freezelayers!,
SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM,
InvDecay, ExpDecay, WeightDecay, stop, Optimiser

include("optimisers.jl")
21 changes: 17 additions & 4 deletions src/optimise/train.jl
@@ -6,13 +6,13 @@ import Zygote: Params, gradient
update!(opt, p, g)
update!(opt, ps::Params, gs)

Perform an update step of the parameters `ps` (or the single parameter `p`)
according to optimizer `opt` and the gradients `gs` (the gradient `g`).

As a result, the parameters are mutated and the optimizer's internal state may change.

update!(x, x̄)

Update the array `x` according to `x .-= x̄`.
"""
function update!(x::AbstractArray, x̄)
@@ -30,6 +30,19 @@ function update!(opt, xs::Params, gs)
end
end

# Freeze layers: delete every field of the selected layers from `ps` so the optimiser no longer updates them
function freezelayers!(ps::Params, m, layer_indexes::Vector{Int64})
for layer_index in layer_indexes
layer = m.layers[layer_index]
param_names = fieldnames(typeof(layer))
for param_name in param_names
delete!(ps, getfield(layer, param_name))
end
end

return ps
end
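
# Usage sketch (assuming a `Chain` model `m` and a suitable `loss`, `data`, and `opt`):
#
#     ps = freezelayers!(params(m), m, [1, 3])
#     train!(loss, ps, data, opt)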

# Callback niceties
call(f, xs...) = f(xs...)
runall(f) = f
@@ -61,7 +74,7 @@ end
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
backpropagation and calls the optimizer `opt`.

In case datapoints `d` are of numeric array type, assumes no splatting is needed
and computes the gradient of `loss(d)`.

Takes a callback as keyword argument `cb`. For example, this will print "training"
51 changes: 51 additions & 0 deletions test/freezelayers.jl
@@ -0,0 +1,51 @@
using Flux: Chain, Dense, softmax, params, train!
using Flux: crossentropy, ADAM, freezelayers!
using Test

function collect_params(m, layer_indexes::Vector{Int64})
ps = []
for layer_index in layer_indexes
layer = m.layers[layer_index]
param_names = fieldnames(typeof(layer))
for param_name in param_names
p = deepcopy(getfield(layer, param_name))
push!(ps, p)
end
end

return ps
end

@testset "FreezeLayers" begin
# Frozen and trainable layer indices
freezed_layer_indexes = [1, 3]
mutable_layer_indexes = [2, 4]

# Model
m = Chain(
Dense(4, 2),
Dense(2, 3),
Dense(3, 4),
Dense(4, 3),
softmax
)

# Original params
ps_freezed = collect_params(m, freezed_layer_indexes)
ps_mutable = collect_params(m, mutable_layer_indexes)

# Update params
train!(
(x, y) -> crossentropy(m(x), y),
freezelayers!(params(m), m, freezed_layer_indexes),
[(rand(4, 10), rand(1, 10))],
ADAM(0.005)
)

# Params after update
new_ps_freezed = collect_params(m, freezed_layer_indexes)
new_ps_mutable = collect_params(m, mutable_layer_indexes)

@test new_ps_freezed == ps_freezed
@test new_ps_mutable != ps_mutable
end
8 changes: 6 additions & 2 deletions test/runtests.jl
@@ -1,6 +1,6 @@
using Flux
using Flux.Data
using Test
using Random, Statistics, LinearAlgebra
using Documenter
using IterTools: ncycle
@@ -21,6 +21,10 @@ Random.seed!(0)
include("optimise.jl")
end

@testset "FreezeLayers" begin
include("freezelayers.jl")
end

@testset "Data" begin
include("data.jl")
end