
Conversation

@unalmis (Collaborator) commented Jul 30, 2025

Progress toward addressing #1154.

The attached benchmark must be run on the ku/nufft branch.
250x speed improvement, from ~15 seconds to ~50 milliseconds.
The improvement would be greater if the FourierZernike basis were padded as discussed in #1243.
Then we would avoid the 3D spectral-to-real transform as well as the N^2 FFTs of size N.
(Then this computation would likely be in the microsecond range.)
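
For readers unfamiliar with the technique, here is a minimal sketch of the partial-summation idea on a tensor-product grid (illustrative only; the function and array names are hypothetical, not the DESC implementation):

```python
import jax.numpy as jnp

def eval_naive(c, R, F):
    """f(r_i, z_j) = sum_{l,m} c[l, m] * R[i, l] * F[j, m], summed point by point."""
    # O(Nr * Nz * L * M): every radial/toroidal basis product is recomputed at every grid point
    return jnp.einsum("lm,il,jm->ij", c, R, F)

def eval_partial_sum(c, R, F):
    """Same result via partial summation: contract over l first, then over m."""
    tmp = R @ c        # (Nr, M): the radial sum is done once and reused for every z_j
    return tmp @ F.T   # (Nr, Nz): O(Nr * L * M + Nr * Nz * M) total
```

Both functions return the same array; the second just reorders the sums so the expensive radial evaluation is not repeated across the toroidal grid, which is the essence of the speedup described above.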

@unalmis unalmis changed the base branch from master to ku/NFP July 30, 2025 20:02
@unalmis unalmis added the skip_changelog No need to update changelog on this PR label Jul 30, 2025
Base automatically changed from ku/NFP to master July 30, 2025 20:15

TODO(#1243): Do proper partial summation once the DESC
basis is improved to store the padded tensor product basis.
https://github.com/PlasmaControl/DESC/issues/1243#issuecomment-3131182128.
Collaborator Author

@PlasmaControl/desc-dev: the header of #1508 says "figure out how to do FourierZernike". Basically, until the basis is padded (#1243 (comment)) there is no efficient implementation without loops.

@github-actions bot (Contributor) commented Jul 30, 2025

Memory benchmark result

| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| --- | --- | --- | --- | --- | --- | --- |
| test_objective_jac_w7x | 0.76 % | 3.949e+03 | 3.979e+03 | 30.18 | 33.55 | 29.94 |
| test_proximal_jac_w7x_with_eq_update | -1.19 % | 6.832e+03 | 6.751e+03 | -81.02 | 161.84 | 160.53 |
| test_proximal_freeb_jac | 0.07 % | 1.320e+04 | 1.321e+04 | 9.22 | 78.41 | 76.71 |
| test_proximal_freeb_jac_blocked | 0.41 % | 7.602e+03 | 7.633e+03 | 30.99 | 67.50 | 68.60 |
| test_proximal_freeb_jac_batched | -0.65 % | 7.619e+03 | 7.570e+03 | -49.36 | 69.21 | 68.00 |
| test_proximal_jac_ripple | -0.49 % | 7.550e+03 | 7.513e+03 | -37.08 | 69.00 | 70.33 |
| test_proximal_jac_ripple_spline | -0.32 % | 3.480e+03 | 3.469e+03 | -11.02 | 72.02 | 71.52 |
| test_eq_solve | -0.05 % | 2.024e+03 | 2.023e+03 | -1.04 | 124.18 | 124.57 |

For the memory plots, go to the summary of the Memory Benchmarks workflow and download the artifact.

@unalmis unalmis added the performance New feature or request to make the code faster label Jul 30, 2025
@unalmis unalmis self-assigned this Jul 31, 2025
@review-notebook-app: Check out this pull request on ReviewNB to see visual diffs and provide feedback on Jupyter notebooks.

@unalmis unalmis marked this pull request as ready for review July 31, 2025 09:08
@unalmis unalmis requested review from a team, YigitElma, ddudt, dpanici, f0uriest and rahulgaur104 and removed request for a team July 31, 2025 09:08
@unalmis unalmis removed the run_benchmarks Run timing benchmarks on this PR against current master branch label Aug 13, 2025
@unalmis unalmis linked an issue Aug 13, 2025 that may be closed by this pull request
@unalmis unalmis requested a review from dpanici August 13, 2025 21:40
Co-authored-by: Dario Panici <37969854+dpanici@users.noreply.github.com>
@f0uriest (Member) left a comment

Just so I understand:

  • map_coordinates doesn't change; just internally we use a different shortcut for dealing with Clebsch coordinates
  • map_clebsch_coordinates is now specialized to tensor-product grids and uses partial summation to avoid re-evaluating the radial polynomials?

Can this also be used to speed up get_rtz_grid, since that's a meshgrid in Clebsch coordinates?

def compute_theta_coords(
self, flux_coords, L_lmn=None, tol=1e-6, maxiter=20, full_output=False, **kwargs
@staticmethod
def _map_clebsch_coordinates(
Member

does this need to be a method here? I'd vote for just using the function directly where needed

Collaborator Author

It has to be an object method to avoid circular imports.

ResolutionWarning,
msg="High frequency lambda modes will be truncated in coordinate mapping.",
)
lmbda_minus_iota_omega = L.transform(L_lmn)
Member

why is this lambda-iota*omega and not just lambda?

Collaborator Author

The root-finding from DESC coordinates to Clebsch coordinates involves both lambda and omega.

**kwargs,
)
@partial(jnp.vectorize, signature="(),(),(m)->()")
def vecroot(theta0, alpha, c_m):
Member

I think you'll still have the issue @YigitElma mentioned about recompilation here. The fix would be to make rootfun/jacfun/vecroot global (private) functions rather than define them locally within this function

Collaborator Author

In ku/nufft I jitted the outer function as well.
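
For context, here is a minimal sketch of the retracing concern and the two fixes mentioned (hoisting the helpers to module scope, or jitting the outer function); the names and the toy residual are stand-ins, not the DESC code:

```python
from functools import partial

import jax
import jax.numpy as jnp

def _residual(theta, alpha, lmbda):
    # toy residual: theta + lmbda * sin(theta) - alpha = 0
    return theta + lmbda * jnp.sin(theta) - alpha

@partial(jnp.vectorize, signature="(),(),()->()")
def _vecroot(theta0, alpha, lmbda):
    # a few fixed Newton iterations stand in for the real root solve
    theta = theta0
    for _ in range(5):
        theta = theta - _residual(theta, alpha, lmbda) / (1.0 + lmbda * jnp.cos(theta))
    return theta

@jax.jit
def map_angles(theta0, alpha, lmbda):
    # Because the outer function is jitted and _vecroot is a module-level object,
    # repeated calls with the same shapes reuse one trace. Defining and jitting
    # _vecroot inside an un-jitted outer function would create a fresh function
    # object, and hence a fresh compilation, on every call.
    return _vecroot(theta0, alpha, lmbda)
```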

Collaborator Author

This function should be fine, but you are welcome to fix map_coordinates.

@unalmis (Collaborator Author) commented Aug 13, 2025

Just so I understand:

* `map_coordinates` doesn't change, just internally we use different shortcut for dealing with clebsch coordinates

* `map_clebsch_coordinates` is now specialized to tensor product grids and uses partial summation to avoid re-evaluating the radial polynomials?

Can this also be used to speed up get_rtz_grid? since that's a meshgrid in clebsch coordinates?

Yes. Yes, but it also avoids evaluating the toroidal series. Yes for the third as well, but I left that as is for the following reason. The partial summation implemented here has a totally unnecessary FourierZernike spectral-to-real transform and then, again, unnecessary $N^2$ FFTs of size $N$. To avoid this, I suggested padding the FourierZernike basis modes in issue #1243 to make the partial summation trivial. Until the proper partial summation is implemented, I don't want to change the API of get_rtz_grid.

@unalmis unalmis requested a review from f0uriest August 13, 2025 23:30
**kwargs,
)

def _compute_iota_under_jit(self, rho, params=None, profiles=None, **kwargs):
Collaborator Author

This is used in map_coordinates and Bounce2D.compute_theta.

@rahulgaur104 (Collaborator) left a comment

Mostly minor comments. Addressing them will help future development and developers.

iota = eq._compute_iota_under_jit(coords, params, profiles, **kwargs)
rho, alpha, zeta = coords.T
omega = 0 # TODO(#568)
coords = jnp.column_stack([rho, alpha + iota * (zeta + omega), zeta])
Collaborator

If zeta is the generalized toroidal angle, don't we assume zeta = phi + omega, where phi is the cylindrical toroidal angle?
So shouldn't theta_PEST = alpha + iota * (zeta - omega)?

Collaborator Author

Phi = zeta + omega, and theta_PEST = theta + lambda. The left-hand sides of these relations are defined quantities: Phi is the cylindrical toroidal angle, and theta_PEST is the angle in which the field lines are straight in the (theta_PEST, Phi) plane. When we mention generalizing the angles, we refer to changing the meaning of the angles "zeta" and "theta". These relations must still hold, and hence the stream functions must negate the change in zeta and theta.

Collaborator Author

alpha = theta_PEST - iota * phi
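
Collecting the relations from this thread in one place (just restating the above, nothing new):

$$\phi = \zeta + \omega, \qquad \theta_{\mathrm{PEST}} = \theta + \lambda, \qquad \alpha = \theta_{\mathrm{PEST}} - \iota\,\phi = \theta + \lambda - \iota\,(\zeta + \omega)$$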

guess=None,
period=np.inf,
lmbda,
theta0=None,
Collaborator

theta0 is a very bad variable name. theta0 is dangerously close to theta_0, which people in the gyrokinetics community use in flux tube codes to define the location of vanishing integrated magnetic shear.
Why did you change it from guess? I strongly recommend you change it back. guess is great because it immediately tells you what it is; theta0 requires the user to figure out what is happening in the rest of the code.

Collaborator

As a rule of thumb, please don't make unnecessary changes to variable names.

@unalmis (Collaborator Author), Aug 14, 2025

  1. This is a new function, so nothing is an unnecessary change because nothing is a change; it is new.¹
  2. theta0 is the initial guess in the root finding for theta. I will change it to guess as you requested. Whatever the name is, the public documentation next to the parameter (theta0 : jnp.ndarray : Optional initial guess for the computational coordinates.) is there, so no one has to read code to figure out what it should be.

Footnotes

  1. The function _map_clebsch_coordinates is private and not used anywhere in DESC. If there is external code that was using this function, we are not responsible for breaking API convention.

out : ndarray
Shape (k, 3).
DESC computational coordinates [ρ, θ, ζ].
info : tuple
Collaborator

Why did you remove info? This is not a good coding practice. This is basically passive encryption.

Collaborator

What happens if the root finding fails for some reason and I want to debug it?

Collaborator

If the residual is calculated elsewhere, ignore this comment.

@unalmis (Collaborator Author), Aug 14, 2025

This is a new function, see #1826 (comment), so I disagree with the statement about coding practice. Whatever code users have that uses root finding, nothing has changed. They can still get their info tuple.

In this new function, I have not added functionality to return auxiliary information about the root finding because that is not possible: functions that are decorated with jnp.vectorize, such as this one, must return arrays, and info is not an array.

Collaborator Author

I am not convinced that this root finding can fail by the way.
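
As a practical aid in the spirit of the "residual calculated elsewhere" suggestion, one can always verify the returned roots after the fact. A hedged sketch with the same toy residual as the earlier sketch (the real residual lives inside DESC's solver, so the names here are illustrative):

```python
import jax.numpy as jnp

def check_roots(theta, alpha, lmbda, tol=1e-6):
    """Flag points where a toy root solve theta + lmbda*sin(theta) = alpha did not converge."""
    res = jnp.abs(theta + lmbda * jnp.sin(theta) - alpha)
    return res, res < tol  # residual magnitude and a per-point convergence mask
```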

@@ -79,8 +79,8 @@ def bounce_points(pitch_inv, knots, B, dB_dz, num_well=None):

"""
intersect = polyroot_vec(
Collaborator

This routine finds the intersection of lines of constant pitch with the magnetic field B, i.e., the bounce points.
B is an array with shape (rho, alpha, 1/lambda (inv_pitch), something, num_wells).
Is that right?

@unalmis (Collaborator Author), Aug 14, 2025

I put the shape of all my stuff in the public documentation string (docstring for short) because I find it helpful. Here is the definition of B and its shape in that docstring.

B : jnp.ndarray
        Shape (..., N - 1, B.shape[-1]).
        Polynomial coefficients of the spline of B in local power basis.
        Last axis enumerates the coefficients of power series. Second to
        last axis enumerates the polynomials that compose a particular spline.

When you see Shape (..., N - 1, B.shape[-1]) in Python, that means the code is making the following contract with you, the developer:

The last two axes need to have that shape. All the leading axes ... don't matter; the code is agnostic to them. If there are leading axes, the code will simply perform whatever it does to the last two axes in a vectorized manner. That is the same contract that numpy makes with you. Numpy calls this contract its "broadcasting conventions".

Less experienced developers will mess up this broadcasting contract, and instead may promise broadcasting on the trailing axes of the form Shape (N - 1, B.shape[-1], ...). You should NEVER do that. If you broadcast on the trailing axes like that you break from numpy convention, and both your code and the user-facing code will have to have a million transposes and reshapes to get their stuff working with what you wrote.
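
A small sketch of the leading-axes contract described above (toy numbers; only the last two axes of B carry meaning, matching the docstring shape (..., N - 1, B.shape[-1]), and the coefficient ordering is an assumption for this sketch):

```python
import numpy as np

def eval_splines_at_half(B):
    """Evaluate each local polynomial at x = 0.5 via Horner's rule.

    B has shape (..., N - 1, ncoef) with coefficients ordered highest degree
    first (assumed here); the output has shape (..., N - 1).
    """
    out = np.zeros(B.shape[:-1])
    for c in np.moveaxis(B, -1, 0):  # walk the coefficient axis
        out = out * 0.5 + c
    return out

B = np.random.rand(3, 7, 15, 4)        # the leading axes (3, 7) are arbitrary
print(eval_splines_at_half(B).shape)   # (3, 7, 15): leading axes pass through untouched
```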

Collaborator

So what is the shape of B for this problem?
(..., B.shape[-3], B.shape[-2], B.shape[-1]) is an answer, but I meant in terms of coordinates. The docstring does not have the shape here.

@unalmis (Collaborator Author), Aug 14, 2025

N is documented here as the number of knots. B.shape[-1] is documented here as the number of coefficients in each polynomial of the spline; if you have a cubic spline, that is 4. The last axis of pitch is documented here to be the number of pitch angles. The returned shape is documented here to have its last two axes be the number of pitch angles and the number of wells, respectively.

-----
Magnetic field line with label α, defined by B = ∇ψ × ∇α, is determined from
α : ρ, θ, ζ ↦ θ + λ(ρ,θ,ζ) − ι(ρ) [ζ + ω(ρ,θ,ζ)]
α : ρ, θ, ζ ↦ θ + Λ(ρ,θ,ζ) − ι(ρ) [ζ + ω(ρ,θ,ζ)]
Collaborator

I am confused again by this equation, specifically the + sign before omega.
Is the convention zeta = phi + omega or zeta = phi - omega, where phi is the cylindrical angle and zeta is the generalized angle?

Collaborator Author

The latter. One can Ctrl+F search desc/compute/_core.py for name="phi" and name="theta_PEST" to see the definitions.

@unalmis unalmis requested a review from rahulgaur104 August 14, 2025 05:20
@unalmis (Collaborator Author) commented Aug 14, 2025

I addressed the comments. I think at the developer meeting you all wanted me to add more of my pull request comments to the code. Please make a new pull request and add whichever of my comments you would like into the code.

@unalmis unalmis added the run_benchmarks Run timing benchmarks on this PR against current master branch label Aug 14, 2025
@unalmis unalmis requested review from f0uriest and removed request for f0uriest August 14, 2025 19:26
@unalmis unalmis merged commit 771b03c into master Aug 14, 2025
32 checks passed
@unalmis unalmis deleted the ku/partialsum branch August 14, 2025 23:15
@unalmis unalmis linked an issue Aug 14, 2025 that may be closed by this pull request
maya-avida pushed a commit that referenced this pull request Aug 21, 2025
Progress toward addressing #1154.

The [attached
benchmark](https://github.com/user-attachments/files/21745848/benchmark_partial_sum.zip)
must be run on the `ku/nufft` branch.
250x speed improvement, from ~15 seconds to ~50 milliseconds.
The improvement would be greater if the FourierZernike basis were padded as
discussed in #1243.
Then we would avoid the 3D spectral-to-real transform as well as the N^2
FFTs of size N.
(Then this computation would likely be in the microsecond range.)
@unalmis unalmis mentioned this pull request Aug 21, 2025

Labels

P3: Highest Priority, someone is/should be actively working on this
performance: New feature or request to make the code faster
run_benchmarks: Run timing benchmarks on this PR against current master branch
skip_changelog: No need to update changelog on this PR

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Fix #1334
Improve coordinate mapping performance

5 participants