
Use eltype(x) everywhere, ignore typeof(η) #151


Merged: 19 commits, Aug 21, 2023
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Optimisers"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
version = "0.2.20"
version = "0.3.0-DEV"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
7 changes: 4 additions & 3 deletions docs/src/index.md
@@ -7,14 +7,15 @@ These act on one array of parameters:

```julia
# Define a container to hold any optimiser-specific parameters (if any):
struct DecayDescent{T} <: Optimisers.AbstractRule
eta::T
struct DecayDescent <: Optimisers.AbstractRule
eta::Float64
end

# Define an `apply!` rule which encodes how the gradients will be used to
# update the parameters:
function Optimisers.apply!(o::DecayDescent, state, x, x̄)
newx̄ = (o.eta / √state) .* x̄
T = eltype(x)
newx̄ = T(o.eta / √state) .* x̄
nextstate = state + 1
return nextstate, newx̄
end
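# A hedged usage sketch, not part of this diff: it assumes the rest of the docs
# example defines `Optimisers.init(o::DecayDescent, x) = 1` so that `state`
# starts at 1. The point of the change above is that `o.eta::Float64` is
# converted via `T = eltype(x)`, so Float32 parameters stay Float32:
#
#   x = Float32[1, 2, 3]
#   st = Optimisers.setup(DecayDescent(0.1), x)
#   st, x = Optimisers.update(st, x, ones(Float32, 3))
#   eltype(x) == Float32   # true; the Float64 learning rate does not promote x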
14 changes: 7 additions & 7 deletions src/Optimisers.jl
@@ -78,7 +78,7 @@ or [`update!`](@ref).
julia> m = (x = rand(3), y = (true, false), z = tanh);

julia> Optimisers.setup(Momentum(), m) # same field names as m
(x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()), z = ())
(x = Leaf(Momentum(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()), z = ())
```

The recursion into structures uses Functors.jl, and any new `struct`s containing parameters
@@ -91,15 +91,15 @@ julia> struct Layer; mat; fun; end
julia> model = (lay = Layer([1 2; 3 4f0], sin), vec = [5, 6f0]);

julia> Optimisers.setup(Momentum(), model) # new struct is by default ignored
(lay = (), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
(lay = (), vec = Leaf(Momentum(0.01, 0.9), Float32[0.0, 0.0]))

julia> destructure(model)
(Float32[5.0, 6.0], Restructure(NamedTuple, ..., 2))

julia> using Functors; @functor Layer # annotate this type as containing parameters

julia> Optimisers.setup(Momentum(), model)
(lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = ()), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
(lay = (mat = Leaf(Momentum(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = ()), vec = Leaf(Momentum(0.01, 0.9), Float32[0.0, 0.0]))

julia> destructure(model)
(Float32[1.0, 3.0, 2.0, 4.0, 5.0, 6.0], Restructure(NamedTuple, ..., 6))
@@ -120,13 +120,13 @@ See also [`update!`](@ref), which will be faster for models of ordinary `Array`s
```jldoctest
julia> m = (x = Float32[1,2,3], y = tanh);

julia> t = Optimisers.setup(Descent(0.1f0), m)
(x = Leaf(Descent{Float32}(0.1), nothing), y = ())
julia> t = Optimisers.setup(Descent(0.1), m)
(x = Leaf(Descent(0.1), nothing), y = ())

julia> g = (x = [1,1,1], y = nothing); # fake gradient

julia> Optimisers.update(t, m, g)
((x = Leaf(Descent{Float32}(0.1), nothing), y = ()), (x = Float32[0.9, 1.9, 2.9], y = tanh))
((x = Leaf(Descent(0.1), nothing), y = ()), (x = Float32[0.9, 1.9, 2.9], y = tanh))
```
"""
update
@@ -152,7 +152,7 @@ julia> using StaticArrays, Zygote, Optimisers
julia> m = (x = [1f0, 2f0], y = SA[4f0, 5f0]); # partly mutable model

julia> t = Optimisers.setup(Momentum(1/30, 0.9), m) # tree of states
(x = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]), y = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]))
(x = Leaf(Momentum(0.0333333, 0.9), Float32[0.0, 0.0]), y = Leaf(Momentum(0.0333333, 0.9), Float32[0.0, 0.0]))

julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1] # structural gradient
(x = Float32[10.0, 14.0], y = Float32[10.0, 14.0])
18 changes: 9 additions & 9 deletions src/adjust.jl
@@ -23,15 +23,15 @@ julia> Optimisers.freeze!(s.x)
julia> Optimisers.update!(s, m, (x = ([pi], 10pi), y = [100pi])); # with fake gradient

julia> m
(x = ([1.0], 2.0), y = [-0.14159258336972558])
(x = ([1.0], 2.0), y = [-0.14159265358979312])

julia> s
(x = (Leaf(Momentum{Float32}(0.01, 0.9), [0.0], frozen = true), ()), y = Leaf(Momentum{Float32}(0.01, 0.9), [3.14159]))
(x = (Leaf(Momentum(0.01, 0.9), [0.0], frozen = true), ()), y = Leaf(Momentum(0.01, 0.9), [3.14159]))

julia> Optimisers.thaw!(s)

julia> s.x
(Leaf(Momentum{Float32}(0.01, 0.9), [0.0]), ())
(Leaf(Momentum(0.01, 0.9), [0.0]), ())
```
"""
freeze!(tree) = foreach(freeze!, tree)
@@ -72,17 +72,17 @@ To change just the learning rate, provide a number `η::Real`.
julia> m = (vec = rand(Float32, 2), fun = sin);

julia> st = Optimisers.setup(Nesterov(), m) # stored momentum is initialised to zero
(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[0.0, 0.0]), fun = ())
(vec = Leaf(Nesterov(0.001, 0.9), Float32[0.0, 0.0]), fun = ())

julia> st, m = Optimisers.update(st, m, (vec = [16, 88], fun = nothing)); # with fake gradient

julia> st
(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = ())
(vec = Leaf(Nesterov(0.001, 0.9), Float32[-0.016, -0.088]), fun = ())

julia> Optimisers.adjust!(st, 0.123) # change learning rate, stored momentum untouched

julia> st
(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
(vec = Leaf(Nesterov(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
```

To change other parameters, `adjust!` also accepts keyword arguments matching the field
Expand All @@ -93,13 +93,13 @@ julia> fieldnames(Adam)
(:eta, :beta, :epsilon)

julia> st2 = Optimisers.setup(OptimiserChain(ClipGrad(), Adam()), m)
(vec = Leaf(OptimiserChain(ClipGrad{Float32}(10.0), Adam{Float32}(0.001, (0.9, 0.999), 1.19209f-7)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
(vec = Leaf(OptimiserChain(ClipGrad(10.0), Adam(0.001, (0.9, 0.999), 1.0e-8)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())

julia> Optimisers.adjust(st2; beta = (0.777, 0.909), delta = 11.1) # delta acts on ClipGrad
(vec = Leaf(OptimiserChain(ClipGrad{Float32}(11.1), Adam{Float32}(0.001, (0.777, 0.909), 1.19209f-7)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
(vec = Leaf(OptimiserChain(ClipGrad(11.1), Adam(0.001, (0.777, 0.909), 1.0e-8)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())

julia> Optimisers.adjust(st; beta = "no such field") # silently ignored!
(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
(vec = Leaf(Nesterov(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
```
"""
adjust!(tree, eta::Real) = foreach(st -> adjust!(st, eta), tree)
43 changes: 43 additions & 0 deletions src/interface.jl
@@ -7,6 +7,10 @@ const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor}

abstract type AbstractRule end

function Base.show(io::IO, rule::AbstractRule) # makes Adam(0.01f0) prettier
invoke(show, Tuple{IO,Any}, IOContext(io, :compact => true), rule)
end
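# A hedged illustration of the effect, not part of this diff: with the new
# non-parametric rules, `Adam(0.01f0)` stores `eta = Float64(0.01f0)`, a long
# decimal. The compact IOContext above prints it approximately as
#   Adam(0.01, (0.9, 0.999), 1.0e-8)
# rather than spelling out every digit of the converted value.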

###
### setup
###
@@ -225,3 +229,42 @@ Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc)
onevalue(λ::T, x::AbstractArray{T}) where T = map(_ -> λ, x)
onevalue(λ, x::AbstractArray{T}) where T = onevalue(convert(float(T), λ), x)

nonneg(η::Real) = η < 0 ? throw(DomainError(η, "the learning rate cannot be negative")) : η

"""
@def struct Rule; eta = 0.1; beta = (0.7, 0.8); end

Helper macro for defining rules with default values.
The types of the literal values are used in the `struct`,
like this:
```
struct Rule
eta::Float64
beta::Tuple{Float64, Float64}
Rule(eta = 0.1, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
end
```
Any field called `eta` is assumed to be a learning rate, and cannot be negative.
"""
macro def(expr)
Meta.isexpr(expr, :struct) || throw("@def must act on a struct definition")
lines = expr.args[3].args
names, vals = [], []
for i in eachindex(lines)
lines[i] isa Symbol && throw("@def requires a default for every field")
Meta.isexpr(lines[i], :(=)) || continue
name, val = lines[i].args
push!(names, name)
push!(vals, val)
lines[i] = :($name::$typeof($val))
end
rule = Meta.isexpr(expr.args[2], :<:) ? expr.args[2].args[1] : expr.args[2]
check = :eta in names ? :(eta < 0 && throw(DomainError(eta, "the learning rate cannot be negative"))) : nothing
inner = :(function $rule($([Expr(:kw, nv...) for nv in zip(names,vals)]...))
$check
new($(names...))
end)
push!(lines, inner)
esc(expr)
end
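
For reference, a hedged sketch of how a rule might be defined with `@def`, assuming the macro and `AbstractRule` are in scope as they are inside the Optimisers module (the rule name and defaults here are illustrative, not from this PR):

```julia
@def struct MyRule <: AbstractRule
  eta = 0.1          # field type Float64, taken from the literal default
  beta = (0.7, 0.8)  # field type Tuple{Float64, Float64}
end

MyRule()       # defaults become optional positional arguments: MyRule(0.1, (0.7, 0.8))
MyRule(0.5)    # override eta only
MyRule(-1.0)   # throws DomainError: the learning rate cannot be negative
```

The package's own rules are defined this way in this PR, which is why the doctest outputs above no longer show a type parameter such as `Momentum{Float32}`.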
