Differentiation matrices #1789
Conversation
Memory benchmark result

| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| --- | --- | --- | --- | --- | --- | --- |
| test_objective_jac_w7x | -1.54 % | 3.952e+03 | 3.891e+03 | -60.92 | 39.58 | 35.15 |
| test_proximal_jac_w7x_with_eq_update | -3.12 % | 6.692e+03 | 6.483e+03 | -208.90 | 159.30 | 161.00 |
| test_proximal_freeb_jac | -0.45 % | 1.324e+04 | 1.318e+04 | -59.69 | 83.95 | 81.63 |
| test_proximal_freeb_jac_blocked | -1.15 % | 7.666e+03 | 7.578e+03 | -88.34 | 72.62 | 73.92 |
| test_proximal_freeb_jac_batched | -0.09 % | 7.591e+03 | 7.584e+03 | -6.66 | 73.32 | 75.83 |
| test_proximal_jac_ripple | 2.08 % | 3.388e+03 | 3.459e+03 | 70.41 | 60.68 | 62.69 |
| test_proximal_jac_ripple_bounce1d | 1.79 % | 3.488e+03 | 3.551e+03 | 62.56 | 78.07 | 80.38 |
| test_eq_solve | 4.02 % | 2.005e+03 | 2.086e+03 | 80.59 | 128.59 | 129.80 |

For the memory plots, go to the summary of
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #1789 +/- ##
========================================
Coverage 95.77% 95.78%
========================================
Files 101 102 +1
Lines 27728 27855 +127
========================================
+ Hits 26556 26680 +124
- Misses 1172 1175 +3
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| --- | --- | --- | --- | --- |
| test_build_transform_fft_midres | -1.27 +/- 5.47 | -8.90e-03 +/- 3.83e-02 | 6.91e-01 +/- 2.4e-02 | 7.00e-01 +/- 3.0e-02 |
| test_build_transform_fft_highres | -1.79 +/- 3.17 | -1.68e-02 +/- 2.99e-02 | 9.25e-01 +/- 2.8e-02 | 9.41e-01 +/- 1.1e-02 |
| test_equilibrium_init_lowres | -2.69 +/- 2.75 | -1.27e-01 +/- 1.30e-01 | 4.60e+00 +/- 8.3e-02 | 4.73e+00 +/- 1.0e-01 |
| test_objective_compile_atf | -0.30 +/- 3.47 | -1.85e-02 +/- 2.16e-01 | 6.20e+00 +/- 1.3e-01 | 6.22e+00 +/- 1.7e-01 |
| test_objective_compute_atf | -7.00 +/- 18.44 | -1.55e-04 +/- 4.09e-04 | 2.06e-03 +/- 2.4e-04 | 2.22e-03 +/- 3.3e-04 |
| test_objective_jac_atf | +0.89 +/- 2.57 | +1.56e-02 +/- 4.49e-02 | 1.77e+00 +/- 4.0e-02 | 1.75e+00 +/- 2.1e-02 |
| test_perturb_1 | -1.53 +/- 2.34 | -2.20e-01 +/- 3.37e-01 | 1.42e+01 +/- 1.9e-01 | 1.44e+01 +/- 2.8e-01 |
| test_proximal_jac_atf | +0.52 +/- 1.45 | +2.92e-02 +/- 8.16e-02 | 5.64e+00 +/- 6.3e-02 | 5.61e+00 +/- 5.2e-02 |
| test_proximal_freeb_compute | -1.78 +/- 3.65 | -2.95e-03 +/- 6.06e-03 | 1.63e-01 +/- 4.5e-03 | 1.66e-01 +/- 4.1e-03 |
| test_solve_fixed_iter | -0.98 +/- 1.87 | -2.75e-01 +/- 5.24e-01 | 2.77e+01 +/- 2.8e-01 | 2.80e+01 +/- 4.4e-01 |
| test_objective_compute_ripple | -1.11 +/- 2.72 | -2.39e-03 +/- 5.84e-03 | 2.12e-01 +/- 4.5e-03 | 2.15e-01 +/- 3.8e-03 |
| test_objective_grad_ripple | +0.04 +/- 5.41 | +3.70e-04 +/- 5.08e-02 | 9.40e-01 +/- 4.7e-02 | 9.40e-01 +/- 2.0e-02 |
| test_build_transform_fft_lowres | -1.56 +/- 1.95 | -8.80e-03 +/- 1.10e-02 | 5.54e-01 +/- 8.0e-03 | 5.63e-01 +/- 7.6e-03 |
| test_equilibrium_init_medres | -1.55 +/- 1.43 | -7.67e-02 +/- 7.10e-02 | 4.87e+00 +/- 5.5e-02 | 4.95e+00 +/- 4.5e-02 |
| test_equilibrium_init_highres | -5.92 +/- 3.78 | -3.46e-01 +/- 2.21e-01 | 5.50e+00 +/- 7.4e-02 | 5.85e+00 +/- 2.1e-01 |
| test_objective_compile_dshape_current | -6.08 +/- 4.99 | -2.16e-01 +/- 1.77e-01 | 3.33e+00 +/- 1.2e-01 | 3.55e+00 +/- 1.3e-01 |
| test_objective_compute_dshape_current | -7.69 +/- 7.93 | -6.12e-05 +/- 6.32e-05 | 7.36e-04 +/- 3.7e-05 | 7.97e-04 +/- 5.1e-05 |
| test_objective_jac_dshape_current | -4.41 +/- 16.00 | -1.46e-03 +/- 5.30e-03 | 3.17e-02 +/- 2.9e-03 | 3.31e-02 +/- 4.5e-03 |
| test_perturb_2 | -4.33 +/- 1.42 | -7.60e-01 +/- 2.49e-01 | 1.68e+01 +/- 2.2e-01 | 1.75e+01 +/- 1.3e-01 |
| test_proximal_jac_atf_with_eq_update | -0.38 +/- 1.18 | -5.13e-02 +/- 1.60e-01 | 1.36e+01 +/- 7.7e-02 | 1.36e+01 +/- 1.4e-01 |
| test_proximal_freeb_jac | -0.10 +/- 6.67 | -4.77e-03 +/- 3.33e-01 | 4.98e+00 +/- 5.2e-02 | 4.99e+00 +/- 3.3e-01 |
| test_solve_fixed_iter_compiled | -1.13 +/- 1.68 | -1.05e-01 +/- 1.57e-01 | 9.22e+00 +/- 8.2e-02 | 9.32e+00 +/- 1.3e-01 |
| test_LinearConstraintProjection_build | +1.61 +/- 2.34 | +1.33e-01 +/- 1.94e-01 | 8.41e+00 +/- 1.0e-01 | 8.28e+00 +/- 1.6e-01 |
| test_objective_compute_ripple_bounce1d | -0.02 +/- 2.98 | -6.12e-05 +/- 8.62e-03 | 2.89e-01 +/- 8.1e-03 | 2.89e-01 +/- 2.8e-03 |
| test_objective_grad_ripple_bounce1d | -0.08 +/- 2.29 | -8.39e-04 +/- 2.52e-02 | 1.10e+00 +/- 1.9e-02 | 1.10e+00 +/- 1.7e-02 |

GitHub CI performance can be noisy. When evaluating the benchmarks, developers should take this into account.
Compare the Fourier diffmat with the DESC diffmat calculation. Think about the end point!
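For reference, the textbook periodic (Fourier) differentiation matrix can be sketched as follows. This is a minimal, self-contained construction using the standard circulant formula valid for even n; the `fourier_diffmat` name mirrors the one in the diff, but this is not necessarily DESC's implementation. Note that the periodic grid deliberately excludes the right endpoint, which is the end-point subtlety the comment asks about:

```python
import numpy as np

def fourier_diffmat(n):
    """First-derivative differentiation matrix on the periodic grid
    x_k = 2*pi*k/n, k = 0, ..., n-1 (right endpoint excluded).
    Standard circulant formula for even n (Trefethen-style)."""
    h = 2 * np.pi / n
    c = np.zeros(n)
    m = np.arange(1, n)
    c[1:] = 0.5 * (-1.0) ** m / np.tan(m * h / 2)
    # circulant matrix: D[j, k] = c[(j - k) mod n]
    j, k = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
    return c[(j - k) % n]

n = 8
x = 2 * np.pi * np.arange(n) / n  # note: no grid point at x = 2*pi
D = fourier_diffmat(n)
# exact (to roundoff) for trigonometric polynomials resolved by the grid
assert np.allclose(D @ np.sin(x), np.cos(x))
```

Because the grid omits x = 2*pi, there is no duplicated endpoint row, unlike a Chebyshev matrix on a closed interval.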
desc/diffmat_utils.py
Outdated
```python
x = jnp.sin(jnp.pi * jnp.flip(2 * k - n + 1) / (2 * (n - 1)))

# Affine map to [a, b] only when necessary (allows JIT constant folding)
if domain[0] != -1 or domain[1] != 1:
```
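As a quick check (a sketch, not DESC code): the sine form above is algebraically identical to the familiar Chebyshev–Gauss–Lobatto nodes cos(pi * k / (n - 1)); the sine form is typically preferred because it produces nodes that are symmetric about 0 to the last bit in floating point:

```python
import numpy as np

n = 9
k = np.arange(n)
# the sine form from the diff
x_sin = np.sin(np.pi * np.flip(2 * k - n + 1) / (2 * (n - 1)))
# the familiar Chebyshev-Gauss-Lobatto form
x_cos = np.cos(np.pi * k / (n - 1))
assert np.allclose(x_sin, x_cos)
# the sine form is numerically symmetric about the origin
assert np.allclose(x_sin, -x_sin[::-1])
```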
This is not necessary. JIT compilation will recognize addition with static floats that reduces to 0, and multiplication/division with static floats that reduces to 1, as identity operations and remove them from the compiled code.¹ You can check the compiled HLO code if you are curious.
Footnotes
1. I had checked that the compiled HLO code is identical for all the domain shifting in the functions in desc/integrals. ↩
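To illustrate the point above (a hedged sketch; `affine` is a hypothetical stand-in for the mapping in the diff, not DESC's function): with static Python floats for the domain, XLA folds the add-zero/multiply-by-one arithmetic away, so guarding the map with an `if` gains nothing:

```python
import jax
import jax.numpy as jnp

def affine(x, domain=(-1.0, 1.0)):
    # hypothetical stand-in for the affine map in the diff;
    # a and b are static Python floats under jit
    a, b = domain
    return (x + 1.0) * (b - a) / 2.0 + a

x = jnp.linspace(-1.0, 1.0, 5)
# for domain (-1, 1) the map is the identity; XLA's constant
# folding removes the no-op arithmetic from the compiled code
assert jnp.allclose(jax.jit(affine)(x), x)
# the compiled HLO can be inspected directly if curious:
hlo = jax.jit(affine).lower(x).compile().as_text()
```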
…function via transforms, adding an option to choose the spectral ballooning solver, moving the mapping function/automorphism to quad_utils, updating tests, removing Chebyshev diffmatrices for now
```python
@tree_util.register_pytree_node_class
class DiffMat:
```
Is there a reason for not using IOAble here and registering the class manually? If you just inherit from IOAble, it will handle registration, tree_flatten, and tree_unflatten for you.
You can still add your custom __hash__ and __eq__ functions.
No reason. I committed the code I was able to run.
Ok, great! Can you point me to a template to follow? Like another class that looks like this one?
Any class in DESC inherits from that, but yes, it is pretty hidden sometimes. In this case, it should look like:
```python
from desc.io import IOAble

class DiffMat(IOAble):
    """Some docs"""

    _static_attrs = ["_token"]
    # your __init__, __eq__ and __hash__
```

```python
    self._iota_keys + ["ideal ballooning lambda"], eq, iota_grid
)
self._constants = {
    "diffmat": self._diffmat,
```
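For comparison, the manual route the diff currently takes means writing the flatten/unflatten pair by hand, which is exactly the boilerplate IOAble generates. A minimal self-contained sketch (the `D` and `token` attributes are hypothetical, chosen only to show one array leaf and one static field):

```python
import jax
import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
class DiffMat:
    """Minimal sketch of manual pytree registration (hypothetical attrs)."""

    def __init__(self, D=None, token="fourier"):
        self.D = D          # array leaf, traced through jax transforms
        self.token = token  # static aux data (like _static_attrs in IOAble)

    def tree_flatten(self):
        # children (traced) and aux_data (static) for jax
        return (self.D,), self.token

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0], aux)

dm = DiffMat(jnp.eye(3))
leaves, treedef = jax.tree_util.tree_flatten(dm)
dm2 = jax.tree_util.tree_unflatten(treedef, leaves)
assert dm2.token == "fourier"  # static field survives the round trip
```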
In general, I don't see the point of keeping a reference to the same thing in the constants, and eventually it will be cleaner to get rid of the constants (#1769).
```python
)(z)

# --no-verify a = 0.015 # [0.0, 0.1]
```
Were these for debugging?
Yeah. I'll remove that.
```python
D, _ = fourier_diffmat(nz)
Dz = (D @ D) * NFP**2

# Create identity matrices for tensor product
```
These are not necessary, right? You can already use Dx, Dy, Dz, which are the identity in the special case.
Oh, I think you are right. Let me see!
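The reviewer's point can be checked numerically: a Kronecker factor that is the identity never needs to be materialized, because applying the 1-D matrix along the corresponding axis of the tensor-grid data gives the same result. A sketch with arbitrary data (the shapes and names here are illustrative, not DESC's):

```python
import numpy as np

nx, nz = 4, 6
rng = np.random.default_rng(0)
D = rng.standard_normal((nz, nz))  # any 1-D differentiation matrix in z
F = rng.standard_normal((nx, nz))  # values on the (x, z) tensor grid

# tensor-product operator with an explicit identity factor
full = np.kron(np.eye(nx), D) @ F.ravel()

# same result without building the identity or the Kronecker product:
# for row-major vec, kron(A, B) @ vec(F) == vec(A @ F @ B.T), so with
# A = I this is just a matmul along the z axis
direct = (F @ D.T).ravel()

assert np.allclose(full, direct)
```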
```python
domain = [0, 2 * jnp.pi]
a, b = domain
h = (b - a) / n
return jnp.arange(n) * h + a
```
Not important, but why not just `jnp.linspace(domain[0], domain[1], n, endpoint=False)`?
Yeah, I'll change that. Thanks!
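The two constructions produce the same periodic grid (right endpoint excluded), as a quick check confirms:

```python
import jax.numpy as jnp

n = 7
a, b = 0.0, 2.0 * jnp.pi
h = (b - a) / n
grid_arange = jnp.arange(n) * h + a
# same grid in one call: endpoint=False uses step (b - a) / n
grid_linspace = jnp.linspace(a, b, n, endpoint=False)
assert jnp.allclose(grid_arange, grid_linspace)
```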
```python
nzetaperturn=200,
zeta0=None,
Neigvals=1,
diffmat=DiffMat(),
```
When you have this default, it is not creating matrices for you, right?
Yes, no matrices are being created, but the transform inside the registered compute function needs "diffmat". I don't know how to avoid passing an empty diffmat, because the logic in desc/compute/utils.py will then throw an error.
I was able to make this work and not the other option (where I don't pass anything).
If I understand correctly, diffmat = DiffMat() doesn't change the internal computation, and since there is a default it is not necessary; the code will run and do the same operations either way. Can you add this as a comment, or can you show how to properly create these matrices and pass them that way? I would prefer the latter.
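One conventional way to sidestep the empty-default question is a `None` sentinel, so the object is only constructed when the caller does not supply one. A hypothetical sketch (the `DiffMat` stand-in and `setup` function are illustrative, not the DESC signatures):

```python
class DiffMat:
    """Stand-in for the real class (hypothetical)."""

    def __init__(self, D=None):
        self.D = D

def setup(diffmat=None):
    # a None default avoids evaluating DiffMat() at function-definition
    # time and sharing one instance across every call site
    if diffmat is None:
        diffmat = DiffMat()
    return diffmat

a = setup()
b = setup()
assert a is not b  # each call builds its own instance
assert setup(DiffMat(3)).D == 3  # a caller-supplied object passes through
```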
```python
["ideal ballooning lambda"],
eq,
grid,
diffmat=constants["diffmat"],
```
I am slightly confused and have a couple of questions. Most objectives have a constant grid in rho, theta, zeta, so the transforms stay constant. Here, during the optimization, rho, alpha, zeta stay the same and you find the equivalent rho, theta, zeta grid. How does this affect the differentiation matrices? Do they always stay constant? Or how do you know they have the proper distribution they were generated for in the first place?
I would appreciate it if you could explain it to me.
Yeah, no problem. Yes, the rho and zeta coordinates stay the same when we go from the fixed (rho, alpha, zeta) grid to the (rho, theta, zeta) grid. Since the equation is always being solved at the same zeta points, the differentiation matrices are fixed and have the exact same distribution for all field lines and flux surfaces.
Oh, so you are not using alpha or theta derivatives, right? OK, then I understand now.
dpanici
left a comment
`jax.ensure_compile_time_eval` + `stop_gradient` could potentially work here.
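A sketch of that suggestion (the `make_diffmat` helper is hypothetical, not DESC code): build the constant matrix eagerly at trace time with `jax.ensure_compile_time_eval`, and block gradients through it with `stop_gradient`:

```python
import jax
import jax.numpy as jnp

def make_diffmat(n):
    # hypothetical stand-in for a constant matrix build:
    # a simple centered-difference-like stencil
    return jnp.diag(jnp.ones(n - 1), k=1) - jnp.diag(jnp.ones(n - 1), k=-1)

@jax.jit
def apply(f):
    n = f.shape[0]  # static under jit
    with jax.ensure_compile_time_eval():
        D = make_diffmat(n)  # evaluated eagerly at trace time, baked in as a constant
    # no gradient flows through the matrix itself
    return jax.lax.stop_gradient(D) @ f

out = apply(jnp.arange(5.0))
assert out.shape == (5,)
```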
Addresses #1777
`diffmat` object as a part of `transform`