Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
2069500
always re-solve eq in ProximalProjection before first computation of …
dpanici Aug 11, 2025
54e3c66
update changelog
dpanici Aug 11, 2025
b8e33db
Merge branch 'master' into dp/prox-always-solve-first
dpanici Aug 13, 2025
a50454f
add flag to retain old behavior for test purposes
dpanici Aug 13, 2025
8cd9ceb
Merge branch 'dp/prox-always-solve-first' of github.com:PlasmaControl…
dpanici Aug 13, 2025
1f2784a
Merge branch 'master' into dp/prox-always-solve-first
dpanici Aug 13, 2025
c53b08b
Merge branch 'master' into dp/prox-always-solve-first
YigitElma Aug 19, 2025
8988bb3
Merge branch 'master' into dp/prox-always-solve-first
dpanici Aug 19, 2025
6b9f782
Update CHANGELOG.md
dpanici Aug 19, 2025
779a61f
change name of method
dpanici Aug 20, 2025
0883e14
Merge branch 'master' into dp/prox-always-solve-first
dpanici Aug 24, 2025
a477ccf
Merge branch 'master' into dp/prox-always-solve-first
dpanici Aug 25, 2025
e9f3173
make benchmarks not re-solve eq (as they are not meant to test that p…
dpanici Aug 25, 2025
9fa7560
try to undo change to test after adding build()
dpanici Aug 25, 2025
cc0602f
fix test, cannot solve if we are checking history
dpanici Aug 25, 2025
8f6bae0
update changelog
dpanici Aug 26, 2025
25ce857
Merge branch 'master' into dp/prox-always-solve-first
YigitElma Sep 3, 2025
d5373b4
Merge branch 'master' into dp/prox-always-solve-first
dpanici Sep 3, 2025
0167467
Merge branch 'master' into dp/prox-always-solve-first
YigitElma Sep 4, 2025
20c48ef
Merge branch 'master' into dp/prox-always-solve-first
YigitElma Sep 4, 2025
7fa9267
Merge branch 'master' into dp/prox-always-solve-first
YigitElma Sep 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ Bug Fixes
---------
- [Fixes straight field line equilibrium conversion](https://github.com/PlasmaControl/DESC/pull/1880).

Backend

- When using any of the ``"proximal-"`` optimization methods, the equilbrium is now always solved before beginning optimization to the specified tolerance (as determined, for example, by ``options={"solve_options":{"ftol"...}}`` passed to the ``desc.optimize.Optimizer.optimize`` call). This ensures the assumptions of the proximal projection method are enforced starting from the first step of the optimization.


v0.15.0
-------

Expand Down Expand Up @@ -37,6 +42,7 @@ Bug Fixes
- Fixes bug in ``desc.geometry.curve.FourierRZCurve.from_values`` when numpy array is passed in for the ``coords`` argument.

Backend

- Significant changes to how DESC handles static attributes during JIT compilation. Going forward if any class/object has attributes that should be treated as static by `jax.jit`, these should be declared at the class level like `_static_attrs = ["foo", "bar"]`. Generally, non-arraylike attributes such as functions, strings etc should be marked static, as well as any attributes used for control flow. Previously this was done automatically, but in a way that caused a lot of performance bugs and unnecessary recompilation. These changes have been implemented for all classes in the `desc` repository, but if you have custom objectives or other local objects that subclass from `desc` you may need to add this yourself. JAX error messages usually do a good job of alerting you to things that need to be static, and feel free to open an issue with `desc` if you have any questions.
- No longer closes over the field in ``desc.magnetic_fields._core.field_line_integrate``, which can dramatically reduce compile times when the field being traced has large size attributes (for example, when using a ``desc.magnetic_fields._core.SplineMagneticField`` object).

Expand Down
15 changes: 14 additions & 1 deletion desc/optimize/_constraint_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,9 @@ def __init__(
self._objective = objective
self._constraint = constraint
solve_options = {} if solve_options is None else solve_options
self._solve_during_proximal_build = solve_options.pop(
"solve_during_proximal_build", True
) # If user does not want the solve during build, mainly for debug purposes
perturb_options = {} if perturb_options is None else perturb_options
perturb_options.setdefault("verbose", 0)
perturb_options.setdefault("include_f", False)
Expand Down Expand Up @@ -771,7 +774,17 @@ def build(self, use_jit=None, verbose=1): # noqa: C901
[np.atleast_2d(foo) for foo in self._feasible_tangents], axis=-1
)

# history and caching
## history and caching
# first, ensure equilibrium is solved to the
# specified tolerances, necessary as we assume
# eq is solved when taking the derivatives later
if self._solve_during_proximal_build:
self._eq.solve(
objective=self._eq_solve_objective,
constraints=None,
**self._solve_options,
)
# then store the now-solved eq state as the initial state
self._x_old = self.x(self.things)
self._allx = [self._x_old]
self._allxopt = [self._objective.x(*self.things)]
Expand Down
22 changes: 17 additions & 5 deletions tests/benchmarks/benchmark_cpu_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,9 @@ def test_proximal_jac_atf(benchmark):
grid = LinearGrid(M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=np.linspace(0.1, 1, 10))
objective = ObjectiveFunction(QuasisymmetryTwoTerm(eq, grid=grid))
constraint = ObjectiveFunction(ForceBalance(eq))
prox = ProximalProjection(objective, constraint, eq)
prox = ProximalProjection(
objective, constraint, eq, solve_options={"solve_during_proximal_build": False}
)
prox.build()
x = prox.x(eq)
prox.jac_scaled_error(x, prox.constants).block_until_ready()
Expand Down Expand Up @@ -361,7 +363,11 @@ def test_proximal_jac_atf_with_eq_update(benchmark):
constraint,
eq,
perturb_options={"verbose": 3},
solve_options={"verbose": 3, "maxiter": 0},
solve_options={
"verbose": 3,
"maxiter": 0,
"solve_during_proximal_build": False,
},
)
prox.build(verbose=3)
x = prox.x(eq)
Expand All @@ -388,7 +394,9 @@ def test_proximal_freeb_compute(benchmark):
field = ToroidalMagneticField(1.0, 1.0) # just a dummy field for benchmarking
objective = ObjectiveFunction(BoundaryError(eq, field=field))
constraint = ObjectiveFunction(ForceBalance(eq))
prox = ProximalProjection(objective, constraint, eq)
prox = ProximalProjection(
objective, constraint, eq, solve_options={"solve_during_proximal_build": False}
)
obj = LinearConstraintProjection(
prox, ObjectiveFunction((FixCurrent(eq), FixPressure(eq), FixPsi(eq)))
)
Expand All @@ -412,7 +420,9 @@ def test_proximal_freeb_jac(benchmark):
field = ToroidalMagneticField(1.0, 1.0) # just a dummy field for benchmarking
objective = ObjectiveFunction(BoundaryError(eq, field=field))
constraint = ObjectiveFunction(ForceBalance(eq))
prox = ProximalProjection(objective, constraint, eq)
prox = ProximalProjection(
objective, constraint, eq, solve_options={"solve_during_proximal_build": False}
)
obj = LinearConstraintProjection(
prox, ObjectiveFunction((FixCurrent(eq), FixPressure(eq), FixPsi(eq)))
)
Expand Down Expand Up @@ -532,7 +542,9 @@ def _test_objective_ripple(benchmark, spline, method):
]
)
constraint = ObjectiveFunction([ForceBalance(eq)])
prox = ProximalProjection(objective, constraint, eq)
prox = ProximalProjection(
objective, constraint, eq, solve_options={"solve_during_proximal_build": False}
)
prox.build()
x = prox.x(eq)
_ = getattr(prox, method)(x, prox.constants).block_until_ready()
Expand Down
22 changes: 17 additions & 5 deletions tests/benchmarks/benchmark_gpu_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,9 @@ def test_proximal_jac_atf(benchmark):
grid = LinearGrid(M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=np.linspace(0.1, 1, 10))
objective = ObjectiveFunction(QuasisymmetryTwoTerm(eq, grid=grid))
constraint = ObjectiveFunction(ForceBalance(eq))
prox = ProximalProjection(objective, constraint, eq)
prox = ProximalProjection(
objective, constraint, eq, solve_options={"solve_during_proximal_build": False}
)
prox.build()
x = prox.x(eq)
prox.jac_scaled_error(x, prox.constants).block_until_ready()
Expand Down Expand Up @@ -364,7 +366,11 @@ def test_proximal_jac_atf_with_eq_update(benchmark):
constraint,
eq,
perturb_options={"verbose": 3},
solve_options={"verbose": 3, "maxiter": 0},
solve_options={
"verbose": 3,
"maxiter": 0,
"solve_during_proximal_build": False,
},
)
prox.build(verbose=3)
x = prox.x(eq)
Expand All @@ -391,7 +397,9 @@ def test_proximal_freeb_compute(benchmark):
field = ToroidalMagneticField(1.0, 1.0) # just a dummy field for benchmarking
objective = ObjectiveFunction(BoundaryError(eq, field=field))
constraint = ObjectiveFunction(ForceBalance(eq))
prox = ProximalProjection(objective, constraint, eq)
prox = ProximalProjection(
objective, constraint, eq, solve_options={"solve_during_proximal_build": False}
)
obj = LinearConstraintProjection(
prox, ObjectiveFunction((FixCurrent(eq), FixPressure(eq), FixPsi(eq)))
)
Expand All @@ -415,7 +423,9 @@ def test_proximal_freeb_jac(benchmark):
field = ToroidalMagneticField(1.0, 1.0) # just a dummy field for benchmarking
objective = ObjectiveFunction(BoundaryError(eq, field=field))
constraint = ObjectiveFunction(ForceBalance(eq))
prox = ProximalProjection(objective, constraint, eq)
prox = ProximalProjection(
objective, constraint, eq, solve_options={"solve_during_proximal_build": False}
)
obj = LinearConstraintProjection(
prox, ObjectiveFunction((FixCurrent(eq), FixPressure(eq), FixPsi(eq)))
)
Expand Down Expand Up @@ -535,7 +545,9 @@ def _test_objective_ripple(benchmark, spline, method):
]
)
constraint = ObjectiveFunction([ForceBalance(eq)])
prox = ProximalProjection(objective, constraint, eq)
prox = ProximalProjection(
objective, constraint, eq, solve_options={"solve_during_proximal_build": False}
)
prox.build()
x = prox.x(eq)
_ = getattr(prox, method)(x, prox.constants).block_until_ready()
Expand Down
22 changes: 17 additions & 5 deletions tests/benchmarks/memory_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ def test_proximal_jac_w7x_with_eq_update():
constraint,
eq,
perturb_options={"verbose": 0},
solve_options={"verbose": 0, "maxiter": 0},
solve_options={
"verbose": 0,
"maxiter": 0,
"solve_during_proximal_build": False,
},
)
prox.build(verbose=0)
x = prox.x(eq)
Expand All @@ -97,7 +101,9 @@ def test_proximal_freeb_jac():
field = ToroidalMagneticField(1.0, 1.0) # just a dummy field for benchmarking
objective = ObjectiveFunction(BoundaryError(eq, field=field))
constraint = ObjectiveFunction(ForceBalance(eq))
prox = ProximalProjection(objective, constraint, eq)
prox = ProximalProjection(
objective, constraint, eq, solve_options={"solve_during_proximal_build": False}
)
obj = LinearConstraintProjection(
prox, ObjectiveFunction((FixCurrent(eq), FixPressure(eq), FixPsi(eq)))
)
Expand Down Expand Up @@ -125,7 +131,9 @@ def test_proximal_freeb_jac_batched():
jac_chunk_size=100,
)
constraint = ObjectiveFunction(ForceBalance(eq))
prox = ProximalProjection(objective, constraint, eq)
prox = ProximalProjection(
objective, constraint, eq, solve_options={"solve_during_proximal_build": False}
)
obj = LinearConstraintProjection(
prox, ObjectiveFunction((FixCurrent(eq), FixPressure(eq), FixPsi(eq)))
)
Expand All @@ -152,7 +160,9 @@ def test_proximal_freeb_jac_blocked():
deriv_mode="blocked",
)
constraint = ObjectiveFunction(ForceBalance(eq))
prox = ProximalProjection(objective, constraint, eq)
prox = ProximalProjection(
objective, constraint, eq, solve_options={"solve_during_proximal_build": False}
)
obj = LinearConstraintProjection(
prox, ObjectiveFunction((FixCurrent(eq), FixPressure(eq), FixPsi(eq)))
)
Expand Down Expand Up @@ -196,7 +206,9 @@ def _test_proximal_ripple(spline, method):
]
)
constraint = ObjectiveFunction([ForceBalance(eq)])
prox = ProximalProjection(objective, constraint, eq)
prox = ProximalProjection(
objective, constraint, eq, solve_options={"solve_during_proximal_build": False}
)
prox.build(verbose=0)
x = prox.x(eq)
for _ in range(3):
Expand Down
11 changes: 9 additions & 2 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,8 @@ def compute(self, params, constants=None):
np.random.seed(0)
objective = ObjectiveFunction(DummyObjective(things=eq), use_jit=False)
# make gradient super noisy so it stalls
objective.jac_scaled_error = lambda x, *args: objective._jac_scaled_error(
objective.build()
objective.jac_scaled_error = lambda x, *args: objective.jac_scaled_error(
x
) + 1e2 * (np.random.random((objective._dim_f, x.size)) - 0.5)

Expand Down Expand Up @@ -447,7 +448,13 @@ def compute(self, params, constants=None):
options={
"initial_trust_radius": 0.5,
"perturb_options": {"verbose": 0, "order": 1},
"solve_options": {"verbose": 0, "maxiter": 2},
"solve_options": {
"verbose": 0,
"maxiter": 2,
# Hidden kwarg just for debug/tests, to not solve
# during build
"solve_during_proximal_build": False,
},
},
)

Expand Down
Loading