Add Deflation methods #2044
Memory benchmark result

| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
| test_objective_jac_w7x | 3.78 % | 3.873e+03 | 4.020e+03 | 146.25 | 39.40 | 36.86 |
| test_proximal_jac_w7x_with_eq_update | 0.78 % | 6.484e+03 | 6.535e+03 | 50.72 | 165.44 | 164.23 |
| test_proximal_freeb_jac | 0.03 % | 1.320e+04 | 1.320e+04 | 3.70 | 84.43 | 83.44 |
| test_proximal_freeb_jac_blocked | -0.17 % | 7.531e+03 | 7.518e+03 | -13.11 | 74.11 | 75.43 |
| test_proximal_freeb_jac_batched | 0.23 % | 7.492e+03 | 7.509e+03 | 17.27 | 73.00 | 73.61 |
| test_proximal_jac_ripple | -2.19 % | 3.551e+03 | 3.474e+03 | -77.86 | 64.58 | 65.91 |
| test_proximal_jac_ripple_bounce1d | -0.47 % | 3.558e+03 | 3.541e+03 | -16.83 | 76.15 | 77.76 |
| test_eq_solve | 2.22 % | 2.001e+03 | 2.045e+03 | 44.43 | 96.20 | 93.77 |

For the memory plots, go to the summary of
So as I see it, there are two main ways of applying deflation:
- As a multiplicative factor on the objective, e.g. f(x) -> M(x,x*)f(x)
- As an additional inequality constraint, M(x,x*)<r

The new DeflationOperator objective seems to cover the second case, but for the first case ForceBalanceDeflated only works for equilibrium problems. I think it would be better as a sort of "wrapper objective" that can be applied to any objective to multiply it by the deflation operator.
We could possibly combine the two and make it a single objective like

    class DeflationOperator:
        """Multiplicative or constraint type deflation"""
        def __init__(self, objective=None, ...):
            self.objective = objective
        def compute(self, x):
            if self.objective is not None:
                f = self.objective.compute(x)
            else:
                f = 1
            return M(x,x*)*f

This would cover both cases, either treating the deflation as an extra constraint (with objective=None) or applying multiplicative deflation to an arbitrary objective (e.g. by passing objective=ForceBalance())
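A runnable toy version of that sketch may make the two modes easier to compare. It is only an illustration under stated assumptions, not the DESC implementation: `DeflationWrapper`, `deflation_factor`, `known_solutions`, `power` and `sigma` are hypothetical names, the objective is assumed to be a plain callable, the distance is assumed to be a Euclidean norm, and the factor is assumed to be a product over known solutions.

```python
import numpy as np


def deflation_factor(x, known_solutions, power=2, sigma=1.0):
    """Assumed form: product over known solutions of 1 / ||x - x*||**power + sigma."""
    x = np.asarray(x, dtype=float)
    factor = 1.0
    for x_star in known_solutions:
        dist = np.linalg.norm(x - np.asarray(x_star, dtype=float))
        factor *= 1.0 / dist**power + sigma
    return factor


class DeflationWrapper:
    """Hypothetical wrapper covering both the constraint and multiplicative cases."""

    def __init__(self, known_solutions, objective=None, power=2, sigma=1.0):
        self.known_solutions = known_solutions
        self.objective = objective  # callable x -> residual array, or None
        self.power = power
        self.sigma = sigma

    def compute(self, x):
        M = deflation_factor(x, self.known_solutions, self.power, self.sigma)
        if self.objective is None:
            # Constraint mode: return M itself so it can be bounded.
            return np.atleast_1d(M)
        # Multiplicative mode: scale the wrapped residual by M.
        return M * np.atleast_1d(self.objective(x))


# Toy usage with one known solution at the origin.
known = [[0.0, 0.0]]
constraint_value = DeflationWrapper(known).compute([2.0, 0.0])
wrapped_value = DeflationWrapper(
    known, objective=lambda x: np.asarray(x) ** 2 - 1.0
).compute([2.0, 0.0])
```

With `objective=None` the value is just the deflation factor, which suits an inequality constraint; with an objective, the same factor rescales the residual and blows up near a known solution.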
also don't forget to add
…n when all None are passed, add check for lower bound on deflation as constraint
    self._dim_f = 1

    self._is_none_mask = []
    self._is_not_none_mask = []
I don't think you ever have conditionals on the mask in the compute function? It looks like it's just used as the `where` arg to prod/sum?
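For context, here is a small sketch of how a boolean mask behaves as the `where` argument of a reduction. It uses NumPy as a stand-in and assumes the `jax.numpy` reductions behave the same way; `terms` and the mask values are made-up illustration data, not values from this PR.

```python
import numpy as np

# Hypothetical per-solution deflation terms; the middle slot stands for a None entry.
terms = np.array([1.25, 99.0, 2.0])
is_not_none_mask = np.array([True, False, True])

# Masked-out entries are skipped: they act as 1 in a product and 0 in a sum.
masked_prod = np.prod(terms, where=is_not_none_mask)
masked_sum = np.sum(terms, where=is_not_none_mask)

# If every entry is masked out, the reductions fall back to their identities.
all_masked = np.zeros(3, dtype=bool)
empty_prod = np.prod(terms, where=all_masked)
empty_sum = np.sum(terms, where=all_masked)
```

The all-masked case is the one that matters for the "all things are None" logic: a masked sum returns 0 there, while a masked product returns 1.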
desc/objectives/_generic.py
    # if wrapping an objective, but all things are None, make deflation do
    # nothing when multiplying f, so here we add 1 to it as it is 0 right now
    # if all things are None
    deflation_parameter += 1.0
I think you might want this to be `+= jnp.invert(self._not_all_things_to_deflate_are_None)`; that way it only adds 1 if all the deflated things are None.
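A minimal sketch of that suggestion, using NumPy's `np.invert` as a stand-in for `jnp.invert` (both act as a logical NOT on booleans). The function name `pad_deflation` and the sample values are hypothetical and only illustrate the arithmetic.

```python
import numpy as np


def pad_deflation(deflation_parameter, not_all_none):
    """Add 1 only when every deflated quantity is None (flag is False)."""
    # Inverting the boolean flag gives True (counted as 1) only in the all-None case.
    return deflation_parameter + np.invert(np.asarray(not_all_none))


# All deflated things are None: the parameter was 0, so it becomes 1
# and multiplying f by it leaves f unchanged.
all_none_case = pad_deflation(0.0, not_all_none=False)

# At least one deflated thing is present: the parameter is left untouched.
regular_case = pad_deflation(1.25, not_all_none=True)
```

Compared with an unconditional `+= 1.0`, this avoids shifting the deflation value in the regular case.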
Yes, you are right. The logic was getting confusing.
Deflation method motivation: find multiple solutions to non-convex optimization problems (which can include certain equilibrium solves)
This PR adds ways to apply deflation methods in stellarator optimization and equilibrium solving through the new `DeflationOperator` objective.
- `DeflationOperator`, whose cost is simply M(x;y) = 1/(x-y)^p + sigma (to add as a constraint to an optimization, as in the Tarek 2022 work). This can be used as a standalone metric, or another `_Objective` can be passed to it to wrap it and return as the cost M(x;y)f(x), where f(x) is that `_Objective`'s compute value, as is done in usual deflation.

References:
- `"exp"` deflation type

TODO
- `ForceBalanceDeflated` to use pytree inputs for `params_to_deflate_with`

Future work for another PR:
- `_equilibrium` as attribute of DeflationOperator and test using it in proximal-lsq-exact
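To illustrate the multiplicative form M(x;y)f(x) from the description above, here is a scalar toy problem, assuming p = 2 and sigma = 1 and using |x - y| for the distance. The residual `f`, the helper names, and the finite-difference Newton solver are illustration-only assumptions and have nothing to do with the DESC solvers.

```python
def f(x):
    """Toy residual with two roots, x = 1 and x = -1."""
    return x**2 - 1.0


def deflated(x, known_root, power=2, sigma=1.0):
    """M(x; y) * f(x) with M(x; y) = 1 / |x - y|**power + sigma."""
    return (1.0 / abs(x - known_root) ** power + sigma) * f(x)


def newton(func, x0, tol=1e-12, max_iter=50, h=1e-6):
    """Scalar Newton iteration with a central finite-difference derivative."""
    x = x0
    for _ in range(max_iter):
        fx = func(x)
        if abs(fx) < tol:
            break
        dfx = (func(x + h) - func(x - h)) / (2.0 * h)
        x -= fx / dfx
    return x


# Plain Newton from x0 = 2 converges to the nearby root at x = 1.
first_root = newton(f, x0=2.0)

# Same starting point, but with the first root deflated away:
# the iteration is pushed to the other root at x = -1.
second_root = newton(lambda x: deflated(x, first_root), x0=2.0)
```

The deflated residual no longer vanishes at the first root, so the same initial guess can be reused to look for a different solution.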