
Per-leaf freezing #49

Closed · wants to merge 2 commits
2 changes: 2 additions & 0 deletions docs/src/api.md
@@ -33,6 +33,8 @@ Optimisers.OptimiserChain
Optimisers.setup
Optimisers.update
Optimisers.update!
Optimisers.freeze
Optimisers.thaw
```

Calling `Functors.@functor` on your model's layer types by default causes the
22 changes: 22 additions & 0 deletions src/Optimisers.jl
@@ -2,6 +2,7 @@ module Optimisers

using Functors: functor, fmap, isleaf
using LinearAlgebra
using Base: tail

include("interface.jl")
include("rules.jl")
@@ -103,4 +104,25 @@ arrays within the old model (and the old state), it will be faster for models of
"""
update!

"""
Optimisers.freeze(tree, branches) -> tree

Disables training of part of the model by modifying the optimiser states
returned by [`setup`](@ref). The branches to alter are specified by:
* a symbol `:encoder` to shield all nodes within `model.encoder` from `update`,
* a tuple `(:layers, 1, :enc, 3)` to fix all nodes within `model.layers[1].enc[3]`, and
* a vector `[:enc, (:dec, 2), (:dec, 3)]` to act on all the given parts.

The reverse is [`thaw`](@ref), which by default acts on all nodes.
"""
freeze
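# A usage sketch (mirroring the test set in test/runtests.jl below; `Descent` and
# the model shape here are illustrative, not the only possibilities):
#
#   m  = (x = [1.0, 2.0], y = ([3.0, 4.0], sin))
#   st = Optimisers.setup(Descent(0.1), m)
#   st = Optimisers.freeze(st, :y)   # every Leaf under st.y now has frozen == true
#   st, m = Optimisers.update(st, m, (x = [1, 10], y = ([100, 1000], nothing)))
#   m.y[1] == [3.0, 4.0]             # frozen branch untouched; m.x updated as usual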

"""
Optimisers.thaw(tree, branches = ()) -> tree

Removes the restrictions placed by [`freeze`](@ref). By default this walks the complete
tree of optimiser states, but it can also be applied to only some branches.
"""
thaw
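# Likewise, a sketch continuing the example above:
#
#   st = Optimisers.thaw(st)            # clear every frozen flag
#   st = Optimisers.thaw(st, (:y, 1))   # or clear only the one at st.y[1]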

end # module
35 changes: 31 additions & 4 deletions src/interface.jl
@@ -7,13 +7,14 @@ const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor}
struct Leaf{R,S}
rule::R
state::S
frozen::Bool # when true, update! leaves both the state and the parameters unchanged
end

function setup(rule, x; seen = Base.IdSet())
if isnumeric(x)
x in seen && throw(ArgumentError("Optimisers.jl does not at present handle tied weights, sorry."))
isbits(x) || push!(seen, x)
return Leaf(rule, init(rule, x))
return Leaf(rule, init(rule, x), false)
elseif isleaf(x)
return nothing
else
@@ -28,8 +29,9 @@ update!(::Nothing, x, x̄s...) = nothing, x

update!(ℓ::Leaf, x, ::Zero...) = ℓ, x
function update!(ℓ::Leaf, x, x̄s...)
ℓ.frozen && return ℓ, x # frozen leaf: skip the update entirely
s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, base.(x̄s)...)
Leaf(ℓ.rule, s′), subtract!(x, x̄′)
Leaf(ℓ.rule, s′, ℓ.frozen), subtract!(x, x̄′)
end

update!(tree, x, ::Zero...) = tree, x
@@ -56,6 +58,32 @@ isnumeric(x) = false
iswriteable(::DenseArray{<:AbstractFloat}) = true # more elaborate versions are possible, wait until needed?
iswriteable(_) = false

ids(x::NamedTuple{names}) where names = NamedTuple{names}(names) # a map-friendly version of pairs
ids(x::Tuple) = propertynames(x)
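# e.g. ids((x = 1, y = 2)) == (x = :x, y = :y), and ids((3, 4)) == (1, 2),
# so map(f, tree, ids(tree)) calls f on each subtree together with its own key.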

freeze(ℓ::Leaf, addr::Tuple{}, b=true) = Leaf(ℓ.rule, ℓ.state, b)
freeze(::Nothing, addr::Tuple{}, b=true) = nothing
freeze(::Union{Leaf, Nothing}, addr::Tuple, b=true) = error("invalid index $(repr(addr[1])) at leaf node")

function freeze(tree, addr::Tuple, b=true)
isleaf(tree) && return error("Expected Leaf or Nothing; this is not a valid state tree. Perhaps you called freeze on the model?")
isempty(addr) && return map(t -> freeze(t, addr, b), tree)
addr[1] in ids(tree) || error("invalid index $(repr(addr[1])) ∉ $(propertynames(tree))")
map((t,i) -> i==addr[1] ? freeze(t, tail(addr), b) : t, tree, ids(tree))
end
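# e.g. freeze(st, (:y, 1)) descends into st.y, then st.y[1], and flips that Leaf's flag;
# freeze(st, (:y,)) takes the isempty branch above and freezes everything under st.y.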

# freeze(t::Tied, addr::Tuple, b=true) = Tied(t.ties, freeze(t.tree, addr, b)) # for PR42

freeze(tree, addr::Union{Integer, Symbol}, b=true) = freeze(tree, (addr,), b)
function freeze(tree, addr::Vector, b=true)
for a in addr
tree = freeze(tree, a, b)
end
tree
end
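# so freeze(st, [:enc, (:dec, 2)]) is just freeze applied once per entry, left to right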

thaw(tree, addr=()) = freeze(tree, addr, false)

"""
trainable(x::Layer) -> NamedTuple

@@ -118,6 +146,5 @@ function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long
ioc = IOContext(io, :compact => true)
print(ioc, "Leaf(", ℓ.rule, ", ")
show(ioc, ℓ.state)
print(io, ")")
print(io, ", ", ℓ.frozen, ")")
end

17 changes: 17 additions & 0 deletions test/runtests.jl
@@ -140,6 +140,23 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
@test_throws ArgumentError Optimisers.setup(ADAMW(), m2)
end

@testset "freeze/thaw" begin
m = (x=[1,2], y=([3,4], sin));
st = Optimisers.setup(Descent(0.1), m);
st = Optimisers.freeze(st, :y)
st, m = Optimisers.update(st, m, (x=[1,10], y=([100,1000], nothing)));
@test m.x ≈ [0.9, 1.0]
@test m.y[1] == [3, 4]
st = Optimisers.thaw(st)
st, m = Optimisers.update(st, m, (x=[1,10], y=([100,1000], nothing)));
@test m.y[1] ≈ [-7.0, -96.0]
@test Optimisers.freeze(st, :y) == Optimisers.freeze(st, (:y, 1))

@test_throws Exception Optimisers.freeze(st, :z) # no such key
@test_throws Exception Optimisers.freeze(st, (:x, 1)) # too long
@test_throws Exception Optimisers.freeze(m, :x) # model not state
end

@info "finished feature testing"
end
@testset verbose=true "Optimisation Rules" begin