Replies: 12 comments 39 replies
-
I guess a possible change is that instead of the scheduler's …
-
I think the proposed approach exists more or less wholesale in Optimisers.jl already. Some points of interest: …
So why not use Optimisers.jl now and be done with it?
Finally, there is also the more philosophical question of what Optimisers.jl is trying to be. As you noted, Optax is clearly more of a "gradient transformation" library, but what about gradient-free optimizers in Optim.jl? I'm of the opinion that the design space and tradeoffs change quite a bit depending on whether we treat Optimisers.jl as a gradient transformation library, a general optimizer interface, some intersection of the two, or something else entirely.
-
GalacticOptim.jl is already a universal optimization interface. Why build two? Optimisers.jl was just to move the optimizers out of Flux, IIRC.
-
Coming to GalacticOptim hopefully soon.
GalacticOptim uses all explicit state, since implicit state cannot be compatible with non-Julia tools. Its functions have a …
-
I'm trying to find a second to work on SciMLInterface, which will just be a simple package that exports all of the problem and solution types in SciML and the …
-
As a more general note @ToucheSir, this is basically a draft of what I would PR to Optimisers.jl. I think that package has the foundation already laid, and this is just a view of the eventual final product.
-
Interesting development: PyTorch is now adding more functional optimizers for distributed training: https://github.com/pytorch/pytorch/blob/master/torch/distributed/optim/functional_adagrad.py. I haven't looked at the interface or implementation in detail, but at least superficially it bears a strong resemblance to what has been proposed here.
-
Functional seems to be the way to go in newer problem domains. We shouldn't worry about diverging from PyTorch and their interfaces etc.; this isn't Torch.jl (pun intended). Optimisers.jl has most of the optimisers now (modulo some ADAM derivatives; we have a format to copy-paste the code, basically), and the schedulers PR seems to match many of the assumptions in this proposal as well. So if you can imagine changing the behaviour of optimisers as in Optimisers.jl, using it in Flux like FluxML/Flux.jl#1481, with standard schedulers that can be part of hooks / be context aware, we should be good.

```julia
p1(l) = isnan(l) && Flux.skip()

# or
struct Losses
    n
end
ls = Losses([])
p2(l) = append!(ls.n, l)

Flux.train(..., prehook = [p1, p2])
```

is one such way to look at it. Notice also that we don't limit ourselves to implicit params in the PR I mentioned. We can be compatible of course, but this is a clearer interface with the simple loop.
-
I am debating whether or not to retain … I would be interested to hear why …
-
Maybe something like …
-
Starting a new thread so we have somewhere to discuss higher-order optimizers. Xref FluxML/Optimisers.jl#4. Good prior art might be https://github.com/nestordemeure/AdaHessianJax.
-
Now that FluxML/Flux.jl#1325 has landed, thoughts on what we can do with …
-
While working on ParameterSchedulers.jl, I ran into some complications implementing a scheduled optimizer. So, I decided now would be a good time to rethink Flux's optimizer interfaces. Interestingly, the current design is not too far off from Optimisers.jl or Optax. So the goal here is to figure out what the key differences/requirements are. Since my thoughts are somewhat sporadic, I figured it would be more helpful to present the proposed approach first. Anyone curious about my reasoning can keep reading to the bottom.
Proposed approach
Like I mention below, there is little difference in my mind between Flux's current approach and Optax. The key difference is how Optax deals with state: since the state is explicit, the calling functions get to manage it. This is what we need to change in Flux.
One starts by defining an abstract optimizer type. What this lets us do is generically define how to handle `Params` (something Optax does by assuming the use of `jax.tree_multimap`).

Here we see where the functional approach helps us by using explicit state. We can easily initialize an `IdDict` of optimizer state for each model parameter in `xs`. Later, when we get to `ScheduledOptim`, we'll see how this helps. Here, we also define the user-facing, high-level `update!` interface. `update!` is how we deal with non-`AbstractArray` structures.
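For concreteness, something along these lines (all names here are illustrative placeholders, not a final API):

```julia
using Flux: Params

abstract type AbstractOptimiser end

# Build an IdDict of per-parameter state for everything in `xs`.
init(o::AbstractOptimiser, xs::Params) = IdDict(x => init(o, x) for x in xs)

# User-facing, high-level interface: walk the parameters, apply the rule to each
# AbstractArray leaf, and write the update back in place.
function update!(o::AbstractOptimiser, xs::Params, gs, states)
    for x in xs
        gs[x] === nothing && continue
        states[x], dx = apply!(o, x, states[x], gs[x])
        x .-= dx
    end
    return states
end
```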
Next, I used `Momentum` as an example of an optimizer. The implementation is almost the same as Flux's current implementation, except that the velocity is now part of the state and not the struct. I used `NamedTuple`s for the state because they are like anonymous structs, which is essentially what state is. Also notice that `apply!` is still mutating (though the state is immutable), and an optimizer is implemented assuming `∇x` is an `AbstractArray`.
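For instance, a sketch of `Momentum` in this style (hypothetical code, not the actual Optimisers.jl implementation):

```julia
mutable struct Momentum{T} <: AbstractOptimiser
    eta::T   # learning rate
    rho::T   # momentum decay
end

# The velocity is part of the state, not the struct.
init(o::Momentum, x::AbstractArray) = (velocity = zero(x),)

function apply!(o::Momentum, x, state, ∇x::AbstractArray)
    v = @. o.rho * state.velocity + o.eta * ∇x
    ∇x .= v                      # apply! still mutates the gradient in place...
    return (velocity = v,), ∇x   # ...but the state itself is an immutable NamedTuple
end
```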
Next, we define `ScheduledOptim`. I used the anonymous function approach to address which hyperparameter to set. I don't think there's really any compelling way to do this, and the anonymous setter function seems like the most flexible. We could still define something like `lr`/`lr!` for common hyperparameters, and it would compose perfectly well with this approach.
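Concretely, `ScheduledOptim` could look something like this (a sketch; `schedule` is assumed to be callable on an iteration counter, and the counter itself lives in the explicit state):

```julia
struct ScheduledOptim{O, S, F} <: AbstractOptimiser
    optim::O            # wrapped optimizer
    schedule::S         # t -> hyperparameter value
    update_func::F      # (optim, value) -> set the chosen hyperparameter
end
ScheduledOptim(optim, schedule; update_func = (o, val) -> (o.eta = val)) =
    ScheduledOptim(optim, schedule, update_func)

init(o::ScheduledOptim, x::AbstractArray) = (t = 1, optim = init(o.optim, x))

function apply!(o::ScheduledOptim, x, state, ∇x)
    o.update_func(o.optim, o.schedule(state.t))
    optstate, ∇x = apply!(o.optim, x, state.optim, ∇x)
    return (t = state.t + 1, optim = optstate), ∇x
end
```

This is where the explicit state pays off: the schedule's step count is just another piece of state handed back to the caller.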
Unlike Optax, `ScheduledOptim` wraps another optimizer. The approach in Optax is to define `scale_by_schedule`, which can be part of an `Optax.chain`. This isn't a compelling approach to me. The way this works is that optimizers like `scale_by_adam` don't include any learning rate at all. This is unintuitive to me, since every optimizer paper will include a notion of a learning rate. So the "solution" in Optax was to pull the one parameter that they want to schedule out into its own optimizer. This basically means that none of the other hyperparameters can be scheduled.

Finally, just for completeness, I define another composition like `Flux.Optimiser` or `Optax.chain`.
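A sketch of that composition (here called `OptimiserChain` to avoid clashing with the existing `Flux.Optimiser`; details are illustrative):

```julia
struct OptimiserChain{T<:Tuple} <: AbstractOptimiser
    optims::T
end
OptimiserChain(os...) = OptimiserChain(os)

# Each rule gets its own slot in the state.
init(o::OptimiserChain, x::AbstractArray) = map(oi -> init(oi, x), o.optims)

# Thread the gradient through every rule in order, collecting the new states.
function apply!(o::OptimiserChain, x, states, ∇x)
    newstates = ()
    for (oi, si) in zip(o.optims, states)
        si, ∇x = apply!(oi, x, si, ∇x)
        newstates = (newstates..., si)
    end
    return newstates, ∇x
end
```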
Optimizer application
Currently, Flux's optimizers define a common function, `apply!(o, x, Δ)`, that updates `Δ` in-place to change the gradient step based on `x` and the state of `o`. This approach is nice for simplicity, composition, and memory efficiency. The main issue is that `apply!` is called on each model parameter every iteration of the innermost training loop. This becomes tricky when writing schedules as "just another optimizer," because the schedule must be set so that it accounts for how many times `apply!` is invoked over the course of training.

Explicit state
Currently, the signature of `apply!` is `apply!(o, x, dx)`; the state of `o` is stored internally, and it gets updated on each call to `apply!`. A more functional approach would treat `o` as a function that augments some `state` and `dx`. The signature would look like the sketch below.
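A sketch of that form, reusing the hypothetical `Momentum` rule from above (the name `apply` and the argument order are assumptions):

```julia
# apply(o, x, state, dx) -> (state', dx'), with nothing mutated
function apply(o::Momentum, x, state, dx)
    v = @. o.rho * state.velocity + o.eta * dx
    return (velocity = v,), v
end
```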
With this change, the updates to both `dx` and `state` by `dx'` and `state'` can be deferred (if `apply` is non-mutating). It releases control of the passage and evolution of state to the calling function.

Composing optimizers
If an optimizer is abstracted as a function `o(dx) -> Δdx` (i.e. the optimizer returns the change or transformation to `dx`), then Flux's `Optimiser` allows composition of a series of optimizers, `[o1, o2, o3]`, as follows.
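Roughly (with illustrative stand-in rules):

```julia
o1(dx) = clamp.(dx, -1, 1)   # e.g. clip
o2(dx) = 0.9 .* dx           # e.g. decay
o3(dx) = 0.01 .* dx          # e.g. scale by the learning rate

# Composing [o1, o2, o3] threads the gradient through each rule in turn.
chain(os...) = dx -> foldl((g, o) -> o(g), os; init = dx)

step = chain(o1, o2, o3)     # step(dx) == o3(o2(o1(dx)))
```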
This is exactly what `Optax.chain` does.

Another way to think of this is that optimizers are rules about gradient transformation, and compositions are rules about optimizer application. A key to making this possible is deferment. As seen above, deferment lets the calling function decide when to actually change the gradient, and the optimizer is just a rule that says how to change the gradient.
The way Optax guarantees this deferment is by each optimizer being a pair of functions: one for initialization and another for updating. My guess is that this is the cleanest way to approach this in Python + Jax, since they lack multiple dispatch. With multiple dispatch, I'd argue that the current `Flux.apply!` is exactly the same as an Optax `update_fn`. All we need to do is separate out the initialization. I don't even think mutability is that important here. At the end of the day, the operations will be done in-place (whether by construction like Flux or by optimization like Jax). Deferment is not guaranteed by immutability; it's guaranteed by the transformation function (`apply!`) not being called until the composition decides to.

At this point, I think we should have a brief tangent on hyperparameters before returning to the overall design.
Optimizer hyperparameters
This is the issue most relevant to ParameterSchedulers.jl. Every optimizer is a rule parameterized by a collection of hyperparameter variables that are used to transform a gradient. Naturally, this is best expressed in Julia as a struct whose fields are the hyperparameters.
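For example, a sketch:

```julia
mutable struct OptRuleX
    alpha   # e.g. the learning rate
    beta    # e.g. a decay term
end
```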
Here, `alpha` and `beta` are the hyperparameters. The issue becomes that the meaning of a hyperparameter and its access are detached. For example, if `alpha` corresponds to the learning rate (LR), and I know I am dealing with an `o::OptRuleX`, then I can get and set the LR with `o.alpha`. But when writing a function that operates on a generic optimizer, we don't know what field corresponds to the LR.

Potential solutions
Standardized field
This is a non-solution in my opinion, but it's one that I'll mention just to get past it. The main idea is that common fields like the LR have a standard name (e.g. `eta` in the current Flux optimizers). The drawbacks include, among others, generic code reaching for `o.eta` when that field doesn't exist.
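For instance (hypothetical helper), with the `OptRuleX` sketch above:

```julia
set_lr!(o, val) = (o.eta = val)      # assumes every optimizer calls it `eta`
# set_lr!(OptRuleX(0.1, 0.9), 0.01)  # errors: OptRuleX has no field `eta`
```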
Standardized interface functions
This is a reasonable solution. So, to be a "Flux optimizer," your optimizer struct should implement a standard interface function like `lr`/`lr!`. For example, we could write something like the following.
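A sketch, reusing the hypothetical types from the earlier code:

```julia
lr(o::Momentum) = o.eta
lr!(o::Momentum, val) = (o.eta = val; o)

lr(o::OptRuleX) = o.alpha
lr!(o::OptRuleX, val) = (o.alpha = val; o)
```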
This approach addresses the first two problems with the previous solution. But it still doesn't address the last issue of uncommon hyperparameters. A standard interface in Julia is only useful when it's defined in a single place for other packages to extend. This means that any time someone wants a new optimizer to compose well with current and future optimizers, they need to submit a PR to the common base to have their hyperparameter function added. So, an interface function only seems reasonable so long as the number of functions doesn't need to constantly increase.
Anonymous getter/setter functions
This approach is the one initially used by ParameterSchedulers.jl. It completely ignores the problem and leaves the solution up to downstream packages. For example, a scheduling package might accept an anonymous function as an argument that tells it how to set the optimizer parameter to the latest value. `ScheduledOptim.update_func` is a field that stores this function, and the `ScheduledOptim` uses it to set the wrapped optimizer parameter.
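In terms of the hypothetical `ScheduledOptim` sketched above, scheduling a different hyperparameter is just a different setter:

```julia
# Schedule the momentum term instead of the learning rate.
opt = ScheduledOptim(Momentum(0.01, 0.9), t -> min(0.99, 0.9 + 0.001t);
                     update_func = (o, val) -> (o.rho = val))
```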
The main advantage of this approach is that it is the most transparent to the user.

Hyperparameters are types
This is the approach used in FluxTraining.jl. The idea is that the type acts like a pseudo-reference to the hyperparameter field. For example, instead of passing `o.eta` to a scheduler, you would pass the `LearningRate` type. Then the scheduler calls `setparameter!(o, LearningRate, val)`, and the optimizer can extend `setparameter!` so that `o.eta` is updated. This suffers from similar issues to the standard interface, though it is slightly less cumbersome to extend in practice.
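A sketch of that pattern (hypothetical code; FluxTraining.jl's actual names may differ):

```julia
abstract type HyperParameter end
struct LearningRate <: HyperParameter end

setparameter!(o::Momentum, ::Type{LearningRate}, val) = (o.eta = val; o)

# A scheduler then only needs the type, never the field name:
# setparameter!(opt, LearningRate, 0.001)
```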