Conversation

Collaborator

@YigitElma YigitElma commented Dec 16, 2025

  • Unifies all the SGD-type optimizers; developers only need to implement the update rule
  • x_scale is now used with the SGD methods too
  • Adds wrappers for optax optimizers, which can be selected by name with an "optax-" prefix (see the sketch after the snippet below)
  • Any custom optax optimizer can be used via the "optax-custom" method:
        import optax
        from desc.optimize import Optimizer
        from desc.examples import get

        eq = get("DSHAPE")

        # Optimizer
        opt = optax.chain(
            optax.sgd(learning_rate=1.0),
            optax.scale_by_zoom_linesearch(max_linesearch_steps=15),
        )
        optimizer = Optimizer("optax-custom")
        eq.solve(optimizer=optimizer, options={"optax-options": {"update_rule": opt}})
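
  • A minimal sketch of the named-wrapper path, assuming "optax-sgd" is one of the registered "optax-" names and that no extra options are needed (illustrative, not taken verbatim from the PR):

        from desc.optimize import Optimizer
        from desc.examples import get

        eq = get("DSHAPE")

        # select a built-in optax optimizer by its "optax-" prefixed name
        optimizer = Optimizer("optax-sgd")
        eq.solve(optimizer=optimizer)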

@YigitElma YigitElma self-assigned this Dec 16, 2025
Contributor

github-actions bot commented Dec 16, 2025

Memory benchmark result

| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| --- | --- | --- | --- | --- | --- | --- |
| test_objective_jac_w7x | 3.72 % | 3.926e+03 | 4.072e+03 | 146.23 | 41.70 | 36.97 |
| test_proximal_jac_w7x_with_eq_update | 2.06 % | 6.408e+03 | 6.540e+03 | 131.87 | 167.79 | 165.85 |
| test_proximal_freeb_jac | -0.03 % | 1.315e+04 | 1.315e+04 | -4.59 | 86.15 | 85.46 |
| test_proximal_freeb_jac_blocked | 0.25 % | 7.507e+03 | 7.526e+03 | 19.08 | 76.44 | 75.84 |
| test_proximal_freeb_jac_batched | 0.24 % | 7.484e+03 | 7.502e+03 | 17.98 | 74.72 | 74.01 |
| test_proximal_jac_ripple | -2.75 % | 3.578e+03 | 3.480e+03 | -98.41 | 69.81 | 66.51 |
| test_proximal_jac_ripple_bounce1d | 0.63 % | 3.609e+03 | 3.631e+03 | 22.71 | 78.67 | 77.96 |
| test_eq_solve | 0.90 % | 2.029e+03 | 2.047e+03 | 18.23 | 95.50 | 94.28 |

For the memory plots, go to the summary of the Memory Benchmarks workflow and download the artifact.


codecov bot commented Dec 16, 2025

Codecov Report

❌ Patch coverage is 98.18182% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 94.53%. Comparing base (7aa703f) to head (5843eaa).
⚠️ Report is 1 commit behind head on master.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| desc/optimize/stochastic.py | 98.14% | 1 Missing ⚠️ |
Additional details and impacted files
@@           Coverage Diff           @@
##           master    #2041   +/-   ##
=======================================
  Coverage   94.52%   94.53%           
=======================================
  Files         102      102           
  Lines       28712    28750   +38     
=======================================
+ Hits        27141    27178   +37     
- Misses       1571     1572    +1     
| Files with missing lines | Coverage Δ |
| --- | --- |
| desc/optimize/_desc_wrappers.py | 91.17% <100.00%> (+0.13%) ⬆️ |
| desc/optimize/stochastic.py | 98.07% <98.14%> (+1.06%) ⬆️ |

... and 3 files with indirect coverage changes


@YigitElma YigitElma marked this pull request as ready for review December 18, 2025 01:42
@YigitElma YigitElma requested review from a team, ddudt, dpanici, f0uriest, rahulgaur104 and unalmis and removed request for a team December 18, 2025 01:42
@YigitElma YigitElma changed the title from "Add ADAM optimizer" to "Add ADAM and RMSProp optimizers" Dec 18, 2025
Collaborator

@dpanici dpanici left a comment

Just the small docstring fix: it should be explicit that x_scale="auto" does no scaling here.

dpanici
dpanici previously approved these changes Dec 22, 2025
Member

@f0uriest f0uriest left a comment

I'd double check that the x_scale logic is correct

Also, did you look at whether we could just wrap stuff from optax?

From the examples, e.g. https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adam,

it looks like the user could just pass in an optax solver, and then we could just do

opt_state = solver.init(x0)
...
g = grad(x)*x_scale
updates, opt_state = solver.update(g, opt_state, x)
x = optax.apply_updates(x, x_scale*updates)

or something similar. That would give users access to a much wider array of first order optimizers, and save us having to do it all ourselves
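
For illustration, that loop could look roughly like the following, using optax.adam on a toy quadratic in place of a DESC objective (all names here are illustrative, not the actual wrapper code):

    import jax
    import jax.numpy as jnp
    import optax

    def loss(x):
        # toy quadratic standing in for a DESC objective
        return jnp.sum((x - 1.0) ** 2)

    x = jnp.zeros(5)         # initial point (x0)
    x_scale = jnp.ones(5)    # characteristic scale of each variable
    solver = optax.adam(learning_rate=1e-1)
    opt_state = solver.init(x)

    for _ in range(200):
        g = jax.grad(loss)(x) * x_scale               # scaled gradient
        updates, opt_state = solver.update(g, opt_state, x)
        x = optax.apply_updates(x, x_scale * updates)

Both the gradient fed to the solver and the returned update are rescaled by x_scale, mirroring the pseudocode above.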



def sgd(
def generic_sgd(
Member

sgd is technically public (https://desc-docs.readthedocs.io/en/stable/_api/optimize/desc.optimize.sgd.html#desc.optimize.sgd) so if we want to change the name we should keep an alias to the old one with a deprecation warning. That said, I'm not sure we really need to change the name. "SGD" is already used pretty generically in the ML community for a bunch of first order stochastic methods like ADAM, ADAGRAD, RMSPROP, etc

Collaborator Author

Yeah, SGD is in fact the general name. I can revert to the old name and just emphasize that the "sgd" option uses Nesterov momentum. I was trying to make a distinction, I guess.

for the update rule chosen.

- ``"alpha"`` : (float > 0) Learning rate. Defaults to
1e-1 * ||x_scaled|| / ||g_scaled||.
Member

this seems pretty large (steps would be 10% of x), have you checked how robust this is?

Collaborator Author

I was trying to solve equilibria with these, and even though none of them converged, 10% was better for a variety of equilibria. I haven't checked other optimization problems. I reverted the change and added a safeguard against zeros and NaNs.
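
A minimal sketch of that kind of safeguard (the coefficient and fallback values here are placeholders, not necessarily what the code uses):

    import jax.numpy as jnp

    def default_alpha(x_scaled, g_scaled, coeff=1e-2, fallback=1e-6):
        # default learning rate ~ coeff * ||x_scaled|| / ||g_scaled||,
        # guarded against zero norms, NaNs, and Infs
        alpha = coeff * jnp.linalg.norm(x_scaled) / jnp.linalg.norm(g_scaled)
        return jnp.where(jnp.isfinite(alpha) & (alpha > 0), alpha, fallback)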

Where alpha is the step size and beta is the momentum parameter.
Update rule for ``'sgd'``:

.. math::
Member

personally I prefer unicode for stuff like this. TeX looks nice in the rendered html docs, but is much harder to read as code.

Collaborator Author

I mostly agree and don't have a strong stance either way. My general preference is to use LaTeX for complex equations or public-facing objectives that users will first encounter in the documentation (like guiding center equations, optimization algorithms). For internal development notes or specific compute functions that aren't usually viewed on the web, I’m fine with Unicode since it keeps the source code more readable.
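
For reference, one common textbook form of the Nesterov-momentum SGD update under discussion, written in the .. math:: style used in the docstrings (not necessarily the exact expression in this PR):

    .. math::

        v_{k+1} &= \beta v_k - \alpha \nabla f(x_k + \beta v_k) \\
        x_{k+1} &= x_k + v_{k+1}

where :math:`\alpha` is the step size and :math:`\beta` is the momentum parameter.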

Collaborator

unalmis commented Jan 6, 2026

Also, did you look at whether we could just wrap stuff from optax?

or something similar. That would give users access to a much wider array of first order optimizers, and save us having to do it all ourselves

Also, optimistix has trust-region methods with easy-to-use linear solvers, e.g. normal conjugate gradient, etc.

@YigitElma YigitElma changed the title from "Add ADAM and RMSProp optimizers" to "Add a lot of optimizers including optax support to sgd" Jan 7, 2026
@YigitElma YigitElma requested review from dpanici and f0uriest January 7, 2026 23:03
dpanici
dpanici previously approved these changes Jan 21, 2026
name="sgd",
description="Stochastic gradient descent with Nesterov momentum"
+ "See https://desc-docs.readthedocs.io/en/stable/_api/optimize/desc.optimize.sgd.html", # noqa: E501
name=["sgd", "optax-custom"] + ["optax-" + opt for opt in _all_optax_optimizers],
Member

are we planning on deprecating sgd in favor of optax-sgd?

return result


def _sgd(g, v, alpha, beta):
Member

is this the same as some version of optax-sgd? if so I'd vote to remove this and just do something like if method == "sgd": method = "optax-sgd". Then we can simplify a lot of the code here and just always assume we're using optax stuff

Collaborator Author

I am not 100% sure, but it should be equivalent to optax-sgd with momentum=beta, learning_rate=alpha, and nesterov=True. The amount of code dedicated to it is not much, and it also accommodates future implementations (I don't know if anyone wants to add their own sgd optimizers, but anyway). If people want, I can add a deprecation, but it doesn't seem urgent to me.
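
For reference, the claimed equivalent can be constructed directly with optax and fed through the "optax-custom" path from the PR description (the alpha and beta values below are placeholders):

    import optax

    alpha, beta = 1e-2, 0.9  # placeholder step size and momentum
    equivalent = optax.sgd(learning_rate=alpha, momentum=beta, nesterov=True)
    # then: eq.solve(optimizer=Optimizer("optax-custom"),
    #                options={"optax-options": {"update_rule": equivalent}})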

@YigitElma
Collaborator Author

A simple helper test to check whether _all_optax_optimizers is out of date (it assumes inspect, optax, and pytest are imported and that _all_optax_optimizers is available in the test module).

    @pytest.mark.unit
    def test_available_optax_optimizers(self):
        """Test that all optax optimizers are included in _all_optax_optimizers."""
        optimizers = []
        # Optax doesn't have a specific module for optimizers, and there is no specific
        # base class for them, so we have to manually exclude some outliers. The class
        # optax.GradientTransformationExtraArgs is the closest thing, but some other
        # callables that return instances of it are not optimizers. Since the optimizers
        # are functions that return an optax.GradientTransformationExtraArgs, we call
        # each public callable with dummy arguments and check the type of the result.
        names_to_exclude = [
            "GradientTransformationExtraArgs",
            "freeze",
            "scale_by_backtracking_linesearch",
            "scale_by_polyak",
            "scale_by_zoom_linesearch",
            "optimistic_adam",  # deprecated
        ]
        for name, obj in inspect.getmembers(optax):
            if name.startswith("_"):
                continue
            if callable(obj):
                try:
                    sig = inspect.signature(obj)
                    ins = {
                        p.name: 0.1
                        for p in sig.parameters.values()
                        if p.default is inspect.Parameter.empty
                    }
                    if name == "noisy_sgd":
                        ins["key"] = 0
                    out = obj(**ins)
                    if isinstance(out, optax.GradientTransformationExtraArgs):
                        if name not in names_to_exclude:
                            optimizers.append(name)
                except Exception:
                    print(f"Could not instantiate: {name}")
                    pass

        msg = (
            "Wrapped optax optimizers can be out of date. If the newly added callable "
            "is not an optimizer, add it to the names_to_exclude list in this test."
        )
        print(optimizers)
        assert len(set(optimizers)) == len(_all_optax_optimizers), msg
        assert sorted(set(optimizers)) == sorted(_all_optax_optimizers), msg
        assert len(set(_all_optax_optimizers)) == len(_all_optax_optimizers), msg

Collaborator

dpanici commented Jan 28, 2026

We are in favor of removing the sgd code and having "sgd" alias to "optax-sgd".
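
A minimal sketch of what that alias could look like inside the optimizer dispatch (hypothetical helper; the warning follows the earlier suggestion to keep a deprecation path):

    import warnings

    def _resolve_method(method):
        """Map the legacy public "sgd" name onto the optax-backed optimizer."""
        if method == "sgd":
            # keep the old public name working, but point users at the new one
            warnings.warn("'sgd' now dispatches to 'optax-sgd'", DeprecationWarning)
            return "optax-sgd"
        return method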

@YigitElma YigitElma requested review from dpanici and f0uriest January 29, 2026 18:30