Add a lot of optimizers including optax support to sgd
#2041
base: master
Conversation
Memory benchmark result

| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| --- | --- | --- | --- | --- | --- | --- |
| test_objective_jac_w7x | 3.72 % | 3.926e+03 | 4.072e+03 | 146.23 | 41.70 | 36.97 |
| test_proximal_jac_w7x_with_eq_update | 2.06 % | 6.408e+03 | 6.540e+03 | 131.87 | 167.79 | 165.85 |
| test_proximal_freeb_jac | -0.03 % | 1.315e+04 | 1.315e+04 | -4.59 | 86.15 | 85.46 |
| test_proximal_freeb_jac_blocked | 0.25 % | 7.507e+03 | 7.526e+03 | 19.08 | 76.44 | 75.84 |
| test_proximal_freeb_jac_batched | 0.24 % | 7.484e+03 | 7.502e+03 | 17.98 | 74.72 | 74.01 |
| test_proximal_jac_ripple | -2.75 % | 3.578e+03 | 3.480e+03 | -98.41 | 69.81 | 66.51 |
| test_proximal_jac_ripple_bounce1d | 0.63 % | 3.609e+03 | 3.631e+03 | 22.71 | 78.67 | 77.96 |
| test_eq_solve | 0.90 % | 2.029e+03 | 2.047e+03 | 18.23 | 95.50 | 94.28 |

For the memory plots, go to the summary of
Codecov Report ❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff            @@
##           master    #2041   +/-   ##
=======================================
  Coverage   94.52%   94.53%
=======================================
  Files         102      102
  Lines       28712    28750      +38
=======================================
+ Hits        27141    27178      +37
- Misses       1571     1572       +1
```
dpanici left a comment
Just the small docstring fix: it should be explicit that `x_scale="auto"` does no scaling here.
f0uriest left a comment
I'd double check that the x_scale logic is correct.

Also, did you look at whether we could just wrap stuff from optax?
From the examples, e.g. https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adam,
it looks like the user could just pass in an optax solver and then we can just do

```python
opt_state = solver.init(x0)
...
g = grad(x) * x_scale
updates, opt_state = solver.update(g, opt_state, x)
x = optax.apply_updates(x, x_scale * updates)
```

or something similar. That would give users access to a much wider array of first order optimizers, and save us having to do it all ourselves.
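For concreteness, a minimal sketch of that kind of wrapper loop on a toy quadratic, using the public `optax` API (`init`, `update`, `apply_updates`); `run_optax`, `fun`, and `maxiter` are illustrative names, not DESC's interface:

```python
# Sketch only: drive an arbitrary optax solver over a JAX-differentiable objective,
# applying x_scale to the gradient and to the returned updates as suggested above.
import jax
import jax.numpy as jnp
import optax

def run_optax(fun, x0, solver, x_scale=1.0, maxiter=500):
    grad_fun = jax.grad(fun)
    x = x0
    opt_state = solver.init(x0)
    for _ in range(maxiter):
        g = grad_fun(x) * x_scale                      # gradient in scaled variables
        updates, opt_state = solver.update(g, opt_state, x)
        x = optax.apply_updates(x, x_scale * updates)  # map the step back to unscaled x
    return x

# Example: minimize a simple quadratic with Adam.
fun = lambda x: jnp.sum((x - 1.0) ** 2)
x_opt = run_optax(fun, jnp.zeros(3), optax.adam(learning_rate=1e-1))
print(x_opt)  # should approach [1., 1., 1.]
```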
desc/optimize/stochastic.py (Outdated)

```diff
-def sgd(
+def generic_sgd(
```
sgd is technically public (https://desc-docs.readthedocs.io/en/stable/_api/optimize/desc.optimize.sgd.html#desc.optimize.sgd) so if we want to change the name we should keep an alias to the old one with a deprecation warning. That said, I'm not sure we really need to change the name. "SGD" is already used pretty generically in the ML community for a bunch of first order stochastic methods like ADAM, ADAGRAD, RMSPROP, etc
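A minimal sketch of such an alias, assuming the new name were `generic_sgd`; the names and the warning message here are illustrative, not the PR's actual code:

```python
# Sketch only: keep the old public name, forwarding to the renamed function
# and emitting a DeprecationWarning so existing user code keeps working.
import warnings

def generic_sgd(*args, **kwargs):
    ...  # placeholder for the renamed optimizer driver

def sgd(*args, **kwargs):
    """Deprecated alias for ``generic_sgd``."""
    warnings.warn(
        "`sgd` has been renamed to `generic_sgd` and will be removed in a future "
        "release; use `generic_sgd` instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return generic_sgd(*args, **kwargs)
```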
Yeah, SGD is in fact the general name. I can revert to the old name and just emphasize that the "sgd" option uses Nesterov momentum. I was trying to make a distinction, I guess.
desc/optimize/stochastic.py (Outdated)

```
    for the update rule chosen.

    - ``"alpha"`` : (float > 0) Learning rate. Defaults to
      1e-1 * ||x_scaled|| / ||g_scaled||.
```
this seems pretty large (steps would be 10% of x), have you checked how robust this is?
I was trying to solve an equilibrium with these, and even though none of them converged, 10% was better for a variety of equilibria. I haven't checked other optimization problems. I reverted the change and added a safeguard against 0 and NaNs.
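A minimal sketch of that kind of safeguard, assuming the default learning rate comes from the scaled norms as in the docstring excerpt above; the fallback value is illustrative:

```python
# Sketch only: guard the norm-based default learning rate against 0, inf, and NaN.
import jax.numpy as jnp

def default_learning_rate(x_scaled, g_scaled, fallback=1e-2):
    alpha = 1e-1 * jnp.linalg.norm(x_scaled) / jnp.linalg.norm(g_scaled)
    bad = ~jnp.isfinite(alpha) | (alpha <= 0)  # zero gradient/variables -> 0, inf, or NaN
    return jnp.where(bad, fallback, alpha)
```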
desc/optimize/stochastic.py (Outdated)

```
    Where alpha is the step size and beta is the momentum parameter.
    Update rule for ``'sgd'``:

    .. math::
```
personally I prefer unicode for stuff like this. TeX looks nice in the rendered html docs, but is much harder to read as code.
I mostly agree and don't have a strong stance either way. My general preference is to use LaTeX for complex equations or public-facing objectives that users will first encounter in the documentation (like guiding center equations, optimization algorithms). For internal development notes or specific compute functions that aren't usually viewed on the web, I’m fine with Unicode since it keeps the source code more readable.
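For illustration only, here is the same Nesterov-momentum update written both ways in a docstring; the equation is a generic textbook form, not necessarily the exact one used in this PR:

```python
def _nesterov_docstring_example():
    r"""Stochastic gradient descent with Nesterov momentum.

    LaTeX form (renders nicely in the HTML docs):

    .. math::

        v_{k+1} = \beta v_k - \alpha \nabla f(x_k + \beta v_k), \qquad
        x_{k+1} = x_k + v_{k+1}

    Unicode form (easier to read in the source):

        vₖ₊₁ = β·vₖ − α·∇f(xₖ + β·vₖ),    xₖ₊₁ = xₖ + vₖ₊₁
    """
```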
```diff
-    name="sgd",
     description="Stochastic gradient descent with Nesterov momentum"
     + "See https://desc-docs.readthedocs.io/en/stable/_api/optimize/desc.optimize.sgd.html",  # noqa: E501
+    name=["sgd", "optax-custom"] + ["optax-" + opt for opt in _all_optax_optimizers],
```
are we planning on deprecating sgd in favor of optax-sgd?
desc/optimize/stochastic.py (Outdated)

```
    return result


def _sgd(g, v, alpha, beta):
```
Is this the same as some version of optax-sgd? If so, I'd vote to remove this and just do something like `if method == "sgd": method = "optax-sgd"`. Then we can simplify a lot of the code here and just always assume we're using optax stuff.
I am not 100% sure, but it should be equivalent to optax-sgd with momentum=beta, learning_rate=alpha, and nesterov=True. The amount of code dedicated to that is not much, and it also handles future implementations (I don't know if anyone wants to add their own sgd optimizers, but anyway). If people want, I can add a deprecation, but it doesn't seem urgent to me.
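As a sanity check on that equivalence, here is a minimal sketch comparing one generic Nesterov-momentum step against `optax.sgd(learning_rate=alpha, momentum=beta, nesterov=True)`; the hand-rolled update below is a textbook form and may not match `_sgd` in this PR exactly:

```python
# Sketch only: compare a single hand-rolled Nesterov-momentum step against optax.sgd.
import jax
import jax.numpy as jnp
import optax

alpha, beta = 1e-2, 0.9
fun = lambda x: jnp.sum((x - 1.0) ** 2)
x = jnp.zeros(3)
g = jax.grad(fun)(x)

# Hand-rolled step: v_new = beta*v - alpha*g, then a Nesterov lookahead step.
v = jnp.zeros_like(x)
v_new = beta * v - alpha * g
x_manual = x + beta * v_new - alpha * g

# Equivalent optax step.
solver = optax.sgd(learning_rate=alpha, momentum=beta, nesterov=True)
state = solver.init(x)
updates, state = solver.update(g, state, x)
x_optax = optax.apply_updates(x, updates)

print(jnp.allclose(x_manual, x_optax))  # True if the two conventions agree on this step
```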
A simple helper test to check:

```python
@pytest.mark.unit
def test_available_optax_optimizers(self):
    """Test that all optax optimizers are included in _all_optax_optimizers."""
    optimizers = []
    # Optax doesn't have a specific module for optimizers, and there is no specific
    # base class for optimizers, so we have to manually exclude some outliers. The
    # class optax.GradientTransformationExtraArgs is the closest thing, but there
    # are some other classes that inherit from it that are not optimizers. Since
    # the optimizers are actually functions that return an instance of
    # optax.GradientTransformationExtraArgs, we call each public callable and
    # check the type of what it returns.
    names_to_exclude = [
        "GradientTransformationExtraArgs",
        "freeze",
        "scale_by_backtracking_linesearch",
        "scale_by_polyak",
        "scale_by_zoom_linesearch",
        "optimistic_adam",  # deprecated
    ]
    for name, obj in inspect.getmembers(optax):
        if name.startswith("_"):
            continue
        if callable(obj):
            try:
                sig = inspect.signature(obj)
                # Fill every required argument with a small dummy value.
                ins = {
                    p.name: 0.1
                    for p in sig.parameters.values()
                    if p.default is inspect.Parameter.empty
                }
                if name == "noisy_sgd":
                    ins["key"] = 0
                out = obj(**ins)
                if isinstance(out, optax.GradientTransformationExtraArgs):
                    if name not in names_to_exclude:
                        optimizers.append(name)
            except Exception:
                print(f"Could not instantiate: {name}")
    msg = (
        "Wrapped optax optimizers can be out of date. If the newly added callable "
        "is not an optimizer, add it to the names_to_exclude list in this test."
    )
    print(optimizers)
    assert len(set(optimizers)) == len(_all_optax_optimizers), msg
    assert sorted(set(optimizers)) == sorted(_all_optax_optimizers), msg
    assert len(set(_all_optax_optimizers)) == len(_all_optax_optimizers), msg
```
We are in favor of removing the sgd code and having sgd be an alias to optax-sgd.
`x_scale` is now used with SGD methods too. The `optax` optimizers can be called by `optax-<name>`.