
Conversation

@rahulgaur104 rahulgaur104 commented Jun 21, 2025

Addresses #1777

  • Add the differentiation matrix code
  • Pass the differentiation matrices to the compute function as a diffmat object, as part of the transform
  • Solve the ideal ballooning problem using differentiation matrices and compare with the old finite-difference solver
  • Update tests to increase coverage

@rahulgaur104 rahulgaur104 linked an issue Jun 21, 2025 that may be closed by this pull request
github-actions bot commented Jun 21, 2025

Memory benchmark result

|               Test Name                |      %Δ      |    Master (MB)     |      PR (MB)       |    Δ (MB)    |    Time PR (s)     |  Time Master (s)   |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
  test_objective_jac_w7x                 |   -1.54 %    |     3.952e+03      |     3.891e+03      |    -60.92    |       39.58        |       35.15        |
  test_proximal_jac_w7x_with_eq_update   |   -3.12 %    |     6.692e+03      |     6.483e+03      |   -208.90    |       159.30       |       161.00       |
  test_proximal_freeb_jac                |   -0.45 %    |     1.324e+04      |     1.318e+04      |    -59.69    |       83.95        |       81.63        |
  test_proximal_freeb_jac_blocked        |   -1.15 %    |     7.666e+03      |     7.578e+03      |    -88.34    |       72.62        |       73.92        |
  test_proximal_freeb_jac_batched        |   -0.09 %    |     7.591e+03      |     7.584e+03      |    -6.66     |       73.32        |       75.83        |
  test_proximal_jac_ripple               |    2.08 %    |     3.388e+03      |     3.459e+03      |    70.41     |       60.68        |       62.69        |
  test_proximal_jac_ripple_bounce1d      |    1.79 %    |     3.488e+03      |     3.551e+03      |    62.56     |       78.07        |       80.38        |
  test_eq_solve                          |    4.02 %    |     2.005e+03      |     2.086e+03      |    80.59     |       128.59       |       129.80       |

For the memory plots, go to the summary of the Memory Benchmarks workflow and download the artifact.

codecov bot commented Jun 21, 2025

Codecov Report

❌ Patch coverage is 96.32353% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 95.78%. Comparing base (5902967) to head (c27ee23).

| Files with missing lines | Patch % | Lines        |
| ------------------------ | ------- | ------------ |
| desc/diffmat_utils.py    | 96.25%  | 3 Missing ⚠️ |
| desc/compute/utils.py    | 84.61%  | 2 Missing ⚠️ |
Additional details and impacted files
@@           Coverage Diff            @@
##           master    #1789    +/-   ##
========================================
  Coverage   95.77%   95.78%            
========================================
  Files         101      102     +1     
  Lines       27728    27855   +127     
========================================
+ Hits        26556    26680   +124     
- Misses       1172     1175     +3     
| Files with missing lines      | Coverage          | Δ           |
| ----------------------------- | ----------------- | ----------- |
| desc/compute/_stability.py    | 99.24% <100.00%>  | (+0.07%) ⬆️ |
| desc/integrals/quad_utils.py  | 100.00% <100.00%> | (ø)         |
| desc/objectives/_stability.py | 99.27% <100.00%>  | (+0.01%) ⬆️ |
| desc/compute/utils.py         | 97.06% <84.61%>   | (-0.64%) ⬇️ |
| desc/diffmat_utils.py         | 96.25% <96.25%>   | (ø)         |

... and 1 file with indirect coverage changes


@rahulgaur104 rahulgaur104 added the run_benchmarks Run timing benchmarks on this PR against current master branch label Jun 21, 2025
github-actions bot commented Jun 21, 2025

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_midres         |     -1.27 +/- 5.47     | -8.90e-03 +/- 3.83e-02 |  6.91e-01 +/- 2.4e-02  |  7.00e-01 +/- 3.0e-02  |
 test_build_transform_fft_highres        |     -1.79 +/- 3.17     | -1.68e-02 +/- 2.99e-02 |  9.25e-01 +/- 2.8e-02  |  9.41e-01 +/- 1.1e-02  |
 test_equilibrium_init_lowres            |     -2.69 +/- 2.75     | -1.27e-01 +/- 1.30e-01 |  4.60e+00 +/- 8.3e-02  |  4.73e+00 +/- 1.0e-01  |
 test_objective_compile_atf              |     -0.30 +/- 3.47     | -1.85e-02 +/- 2.16e-01 |  6.20e+00 +/- 1.3e-01  |  6.22e+00 +/- 1.7e-01  |
 test_objective_compute_atf              |     -7.00 +/- 18.44    | -1.55e-04 +/- 4.09e-04 |  2.06e-03 +/- 2.4e-04  |  2.22e-03 +/- 3.3e-04  |
 test_objective_jac_atf                  |     +0.89 +/- 2.57     | +1.56e-02 +/- 4.49e-02 |  1.77e+00 +/- 4.0e-02  |  1.75e+00 +/- 2.1e-02  |
 test_perturb_1                          |     -1.53 +/- 2.34     | -2.20e-01 +/- 3.37e-01 |  1.42e+01 +/- 1.9e-01  |  1.44e+01 +/- 2.8e-01  |
 test_proximal_jac_atf                   |     +0.52 +/- 1.45     | +2.92e-02 +/- 8.16e-02 |  5.64e+00 +/- 6.3e-02  |  5.61e+00 +/- 5.2e-02  |
 test_proximal_freeb_compute             |     -1.78 +/- 3.65     | -2.95e-03 +/- 6.06e-03 |  1.63e-01 +/- 4.5e-03  |  1.66e-01 +/- 4.1e-03  |
 test_solve_fixed_iter                   |     -0.98 +/- 1.87     | -2.75e-01 +/- 5.24e-01 |  2.77e+01 +/- 2.8e-01  |  2.80e+01 +/- 4.4e-01  |
 test_objective_compute_ripple           |     -1.11 +/- 2.72     | -2.39e-03 +/- 5.84e-03 |  2.12e-01 +/- 4.5e-03  |  2.15e-01 +/- 3.8e-03  |
 test_objective_grad_ripple              |     +0.04 +/- 5.41     | +3.70e-04 +/- 5.08e-02 |  9.40e-01 +/- 4.7e-02  |  9.40e-01 +/- 2.0e-02  |
 test_build_transform_fft_lowres         |     -1.56 +/- 1.95     | -8.80e-03 +/- 1.10e-02 |  5.54e-01 +/- 8.0e-03  |  5.63e-01 +/- 7.6e-03  |
 test_equilibrium_init_medres            |     -1.55 +/- 1.43     | -7.67e-02 +/- 7.10e-02 |  4.87e+00 +/- 5.5e-02  |  4.95e+00 +/- 4.5e-02  |
 test_equilibrium_init_highres           |     -5.92 +/- 3.78     | -3.46e-01 +/- 2.21e-01 |  5.50e+00 +/- 7.4e-02  |  5.85e+00 +/- 2.1e-01  |
 test_objective_compile_dshape_current   |     -6.08 +/- 4.99     | -2.16e-01 +/- 1.77e-01 |  3.33e+00 +/- 1.2e-01  |  3.55e+00 +/- 1.3e-01  |
 test_objective_compute_dshape_current   |     -7.69 +/- 7.93     | -6.12e-05 +/- 6.32e-05 |  7.36e-04 +/- 3.7e-05  |  7.97e-04 +/- 5.1e-05  |
 test_objective_jac_dshape_current       |     -4.41 +/- 16.00    | -1.46e-03 +/- 5.30e-03 |  3.17e-02 +/- 2.9e-03  |  3.31e-02 +/- 4.5e-03  |
+test_perturb_2                          |     -4.33 +/- 1.42     | -7.60e-01 +/- 2.49e-01 |  1.68e+01 +/- 2.2e-01  |  1.75e+01 +/- 1.3e-01  |
 test_proximal_jac_atf_with_eq_update    |     -0.38 +/- 1.18     | -5.13e-02 +/- 1.60e-01 |  1.36e+01 +/- 7.7e-02  |  1.36e+01 +/- 1.4e-01  |
 test_proximal_freeb_jac                 |     -0.10 +/- 6.67     | -4.77e-03 +/- 3.33e-01 |  4.98e+00 +/- 5.2e-02  |  4.99e+00 +/- 3.3e-01  |
 test_solve_fixed_iter_compiled          |     -1.13 +/- 1.68     | -1.05e-01 +/- 1.57e-01 |  9.22e+00 +/- 8.2e-02  |  9.32e+00 +/- 1.3e-01  |
 test_LinearConstraintProjection_build   |     +1.61 +/- 2.34     | +1.33e-01 +/- 1.94e-01 |  8.41e+00 +/- 1.0e-01  |  8.28e+00 +/- 1.6e-01  |
 test_objective_compute_ripple_bounce1d  |     -0.02 +/- 2.98     | -6.12e-05 +/- 8.62e-03 |  2.89e-01 +/- 8.1e-03  |  2.89e-01 +/- 2.8e-03  |
 test_objective_grad_ripple_bounce1d     |     -0.08 +/- 2.29     | -8.39e-04 +/- 2.52e-02 |  1.10e+00 +/- 1.9e-02  |  1.10e+00 +/- 1.7e-02  |

GitHub CI performance can be noisy. Developers should take this into account when evaluating the benchmarks.

rahulgaur104 (Collaborator Author):

Compare the Fourier diffmat with the DESC diffmat calculation. Think about the end point!

@rahulgaur104 rahulgaur104 changed the title Fourier and chebyshev differentiation matrices Differentiation matrices Jul 7, 2025
@rahulgaur104 rahulgaur104 changed the base branch from master to rg/AGNI_var September 2, 2025 19:01
x = jnp.sin(jnp.pi * jnp.flip(2 * k - n + 1) / (2 * (n - 1)))

# Affine map to [a, b] only when necessary (allows JIT constant folding)
if domain[0] != -1 or domain[1] != 1:
Collaborator:

This is not necessary. JIT compilation will recognize additions of static floats that reduce to 0, and multiplications/divisions by static floats that reduce to 1, as identity operations and remove them from the compiled code.1 You can check the compiled HLO code to see, if curious.

Footnotes

  1. I had checked the compiled HLO code is identical for all the domain shifting for the functions in desc/integrals.
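The claim can be checked directly with a toy affine map (a minimal sketch; `to_domain` is a hypothetical stand-in, not DESC's code). With the default static endpoints the map reduces algebraically to the identity, and XLA folds the static arithmetic away:

```python
import jax
import jax.numpy as jnp


def to_domain(x, a=-1.0, b=1.0):
    # Affine map from [-1, 1] to [a, b]; with the default static
    # constants this reduces algebraically to the identity.
    return 0.5 * (b - a) * (x + 1.0) + a


x = jnp.linspace(-1.0, 1.0, 5)
mapped = jax.jit(to_domain)(x)
# To inspect what XLA actually compiled, one can look at:
# print(jax.jit(to_domain).lower(x).compile().as_text())
```

The compiled code can be compared against a plain identity function to confirm the folding.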

@dpanici dpanici marked this pull request as draft October 9, 2025 01:28
…function via transforms, adding an option to choose the spectral ballooning solver, moving the mapping function/automorphism to quad_utils, updating tests, removing Chebyshev diffmatrices for now

@rahulgaur104 rahulgaur104 changed the base branch from rg/AGNI_var to master October 22, 2025 18:35
@rahulgaur104 rahulgaur104 marked this pull request as ready for review October 23, 2025 21:58
@rahulgaur104 rahulgaur104 requested review from a team and f0uriest and removed request for a team October 24, 2025 12:56


@tree_util.register_pytree_node_class
class DiffMat:
Collaborator:

Is there a reason for not using IOAble here and registering the class manually? If you just inherit from IOAble, it will handle registration, tree_flatten, and tree_unflatten for you.

Collaborator:

You can still add your custom __hash__ and __eq__ functions.

rahulgaur104 (Collaborator Author):

No reason. I committed the code I was able to run.
Ok, great! Can you point me to a template to follow? Like another class that looks like this one?

@YigitElma (Collaborator) commented Oct 27, 2025:

Any class in DESC inherits from that, but yes, it is pretty hidden sometimes. In this case, it should look like:

from desc.io import IOAble


class DiffMat(IOAble):
    """Some docs."""

    _static_attrs = ["_token"]

    # your __init__, __eq__ and __hash__

self._iota_keys + ["ideal ballooning lambda"], eq, iota_grid
)
self._constants = {
"diffmat": self._diffmat,
Collaborator:

In general, I don't see a point in having a reference to the same thing in the constants, and eventually it will be cleaner to get rid of the constants entirely (#1769).

)(z)


# --no-verify a = 0.015 # [0.0, 0.1]
Collaborator:

Were these for debugging?

rahulgaur104 (Collaborator Author):

Yeah. I'll remove that.

D, _ = fourier_diffmat(nz)
Dz = (D @ D) * NFP**2

# Create identity matrices for tensor product
Collaborator:

These are not necessary, right? You can already use Dx, Dy, Dz, which are identity matrices in the special case.

rahulgaur104 (Collaborator Author):

Oh, I think you are right. Let me see!
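The reviewer's point can be illustrated with a toy tensor-product example (hypothetical names; D here is just a stand-in matrix, not the output of DESC's fourier_diffmat): lifting a 1-D operator to the full grid with a Kronecker identity factor does the same work as applying the operator along one axis directly, so the explicit identity matrices can be dropped.

```python
import numpy as np

nz, nx = 4, 3
rng = np.random.default_rng(0)
D = rng.standard_normal((nz, nz))  # stand-in 1-D differentiation matrix
F = rng.standard_normal((nx, nz))  # field values: rows index x, columns index z

# Lifting D to the tensor-product grid with an identity Kronecker factor...
lifted = np.kron(np.eye(nx), D) @ F.reshape(-1)
# ...is numerically identical to applying D along the z axis directly:
direct = (F @ D.T).reshape(-1)
```

The direct form avoids materializing the (nx*nz) x (nx*nz) Kronecker product.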

domain = [0, 2 * jnp.pi]
a, b = domain
h = (b - a) / n
return jnp.arange(n) * h + a
Collaborator:

Not important, but why not just jnp.linspace(domain[0], domain[1], n, endpoint=False)?

rahulgaur104 (Collaborator Author):

Yeah, I'll change that. Thanks!

nzetaperturn=200,
zeta0=None,
Neigvals=1,
diffmat=DiffMat(),
Collaborator:

When you have this default, it is not creating matrices for you, right?

rahulgaur104 (Collaborator Author):

Yes, no matrices are being created, but the transform inside the registered compute function needs "diffmat". I don't know how to avoid passing an empty diffmat, because the logic in desc/compute/utils.py will then throw an error.

I was able to make this work and not the other option (where I don't pass anything).

Collaborator:

If I understand correctly, diffmat=DiffMat() doesn't change the internal computation, and since there is a default it is not necessary; the code will run and do the same operations either way. Can you add this as a comment, or can you show how to properly create these matrices and pass them that way? I would prefer the latter.
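One common way to avoid constructing a default object in the function signature (a sketch with hypothetical names, not DESC's actual API) is a None default with lazy construction, which also lets callers pass a pre-built container explicitly:

```python
class DiffMat:
    """Hypothetical stand-in for a differentiation-matrix container."""

    def __init__(self, D=None):
        self.D = D


def solve(field, diffmat=None):
    # Construct the empty container lazily instead of evaluating
    # DiffMat() in the signature; callers can still pass a pre-built one.
    diffmat = DiffMat() if diffmat is None else diffmat
    return diffmat.D is not None


# solve(u) falls back to the empty container;
# solve(u, diffmat=DiffMat(D=[[0.0]])) supplies prebuilt matrices.
```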

["ideal ballooning lambda"],
eq,
grid,
diffmat=constants["diffmat"],
Collaborator:

I am slightly confused and have a couple of questions. Most objectives have a constant grid in rho, theta, zeta, so the transforms stay constant. Here, during the optimization, rho, alpha, zeta stay the same and you find the equivalent rho, theta, zeta grid. How does this affect the differentiation matrices? Do they always stay constant? And how do you know they have the proper distribution they were generated for in the first place?

I would appreciate it if you could explain it to me.

rahulgaur104 (Collaborator Author):

Yeah, no problem. Yes, the rho and zeta coordinates stay the same when we go from the fixed (rho, alpha, zeta) grid to the (rho, theta, zeta) grid. Since the equation is always solved at the same zeta points, the differentiation matrices are fixed and have the exact same distribution for all field lines and flux surfaces.

Collaborator:

Oh, so you are not using alpha or theta derivatives, right? OK, then I understand now.
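For reference, a first-derivative Fourier differentiation matrix on a uniform periodic grid can be built with the textbook formula (a sketch following Trefethen's construction for even n; not necessarily how DESC's fourier_diffmat is implemented):

```python
import numpy as np


def fourier_diffmat(n):
    """First-derivative differentiation matrix on n uniform points of
    [0, 2*pi), for even n (Trefethen, Spectral Methods in MATLAB)."""
    h = 2.0 * np.pi / n
    D = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            if i != j:
                # Skew-symmetric circulant: entries depend only on i - j
                D[i, j] = 0.5 * (-1.0) ** (i - j) / np.tan((i - j) * h / 2.0)
    return D


n = 8
x = np.arange(n) * 2.0 * np.pi / n
D = fourier_diffmat(n)
# Differentiation is spectrally exact for band-limited functions:
# D @ sin(x) recovers cos(x) to machine precision.
```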

@dpanici (Collaborator) left a comment:

jax.ensure_compile_time_eval + stop_gradient could potentially work here.
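A minimal sketch of that suggestion (toy matrix and shapes, not DESC's code): build the matrix under jax.ensure_compile_time_eval so it is evaluated while tracing rather than at every call, then wrap it in stop_gradient so autodiff treats it as a constant.

```python
import jax
import jax.numpy as jnp


def loss(theta, n=4):
    # The (toy) differentiation matrix depends only on the static size n,
    # so ensure_compile_time_eval lets it be computed at trace time.
    with jax.ensure_compile_time_eval():
        D = jnp.eye(n, k=1) - jnp.eye(n, k=-1)
    # stop_gradient keeps the matrix out of the autodiff graph.
    D = jax.lax.stop_gradient(D)
    u = jnp.sin(theta * jnp.arange(n))
    return jnp.sum((D @ u) ** 2)


g = jax.grad(jax.jit(loss))(0.3)
```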

@YigitElma YigitElma marked this pull request as draft December 11, 2025 05:49
Labels: run_benchmarks (Run timing benchmarks on this PR against current master branch), waiting for other PRs

Successfully merging this pull request may close: Differentiation matrices

5 participants