-
Notifications
You must be signed in to change notification settings - Fork 41
Partial summation in coordinate mapping #1826
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
|
||
| TODO(#1243) Do proper partial summation once the DESC | ||
| DESC basis are improved to store the padded tensor product basis. | ||
| https://github.com/PlasmaControl/DESC/issues/1243#issuecomment-3131182128. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@PlasmaControl/desc-dev in #1508 header it says "figure out how to do FourierZernike". Basically until the basis is padded #1243 (comment) there is no efficient implementation without loops.
Memory benchmark result| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
test_objective_jac_w7x | 0.76 % | 3.949e+03 | 3.979e+03 | 30.18 | 33.55 | 29.94 |
test_proximal_jac_w7x_with_eq_update | -1.19 % | 6.832e+03 | 6.751e+03 | -81.02 | 161.84 | 160.53 |
test_proximal_freeb_jac | 0.07 % | 1.320e+04 | 1.321e+04 | 9.22 | 78.41 | 76.71 |
test_proximal_freeb_jac_blocked | 0.41 % | 7.602e+03 | 7.633e+03 | 30.99 | 67.50 | 68.60 |
test_proximal_freeb_jac_batched | -0.65 % | 7.619e+03 | 7.570e+03 | -49.36 | 69.21 | 68.00 |
test_proximal_jac_ripple | -0.49 % | 7.550e+03 | 7.513e+03 | -37.08 | 69.00 | 70.33 |
test_proximal_jac_ripple_spline | -0.32 % | 3.480e+03 | 3.469e+03 | -11.02 | 72.02 | 71.52 |
test_eq_solve | -0.05 % | 2.024e+03 | 2.023e+03 | -1.04 | 124.18 | 124.57 |For the memory plots, go to the summary of |
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Co-authored-by: Dario Panici <37969854+dpanici@users.noreply.github.com>
f0uriest
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just so I understand:
map_coordinatesdoesn't change, just internally we use different shortcut for dealing with clebsch coordinatesmap_clebsch_coordinatesis now specialized to tensor product grids and uses partial summation to avoid re-evaluating the radial polynomials?
Can this also be used to speed up get_rtz_grid? since that's a meshgrid in clebsch coordinates?
| def compute_theta_coords( | ||
| self, flux_coords, L_lmn=None, tol=1e-6, maxiter=20, full_output=False, **kwargs | ||
| @staticmethod | ||
| def _map_clebsch_coordinates( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does this need to be a method here? I'd vote for just using the function directly where needed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It has to be an object method to avoid circular imports.
desc/equilibrium/coords.py
Outdated
| ResolutionWarning, | ||
| msg="High frequency lambda modes will be truncated in coordinate mapping.", | ||
| ) | ||
| lmbda_minus_iota_omega = L.transform(L_lmn) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is this lambda-iota*omega and not just lambda?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The root-finding from desc coords to clebsch coords involves both lambda and omega
desc/equilibrium/coords.py
Outdated
| **kwargs, | ||
| ) | ||
| @partial(jnp.vectorize, signature="(),(),(m)->()") | ||
| def vecroot(theta0, alpha, c_m): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you'll still have the issue @YigitElma mentioned about recompilation here. The fix would be to make rootfun/jacfun/vecroot global (private) functions rather than define them locally within this function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In ku/nufft I jitted the outer function as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function should be fine, but you are welcome to fix map_coordinates.
Yes. Yes, but also avoid evaluating toroidal series. Yes also for the third, but I left that as is for the following reason. Tthe partial summation implemented here has a totally unnecessary FourierZernike spectral to real transform then again an unnecessary |
| **kwargs, | ||
| ) | ||
|
|
||
| def _compute_iota_under_jit(self, rho, params=None, profiles=None, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is used in map_coordinates and Bounce2D.compute_theta
rahulgaur104
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly minor comments. Addressing them will help future development/ developers.
| iota = eq._compute_iota_under_jit(coords, params, profiles, **kwargs) | ||
| rho, alpha, zeta = coords.T | ||
| omega = 0 # TODO(#568) | ||
| coords = jnp.column_stack([rho, alpha + iota * (zeta + omega), zeta]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If zeta is the generalized toroidal angle, don't we assume zeta = phi + Omega where phi is the cylindrical toroidal angle?
So shouldn't theta_PEST = alpha + iota * (zeta - omega)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Phi = zeta + omega, and theta_PEST = theta + lambda. The left hand side of these relations are defined quantities. Phi is the cylindrical toroidal angle and theta_PEST is the angle where the field lines are straight in the (theta_PEST, Phi) plane. When we mention generalizing angles, we refer to changing the meaning of the angles "zeta" and "theta". These relations must still hold and hence the stream functions must negate the change in zeta and theta.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
alpha = theta_PEST - iota phi
desc/equilibrium/coords.py
Outdated
| guess=None, | ||
| period=np.inf, | ||
| lmbda, | ||
| theta0=None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
theta0 is a very bad variable name. theta0 is dangerously close to theta_0 that people in the gyrokinetics communit use in flux tube codes to define the location of vanishing integrated magnetic shear.
Why did you change it from guess? I strongly recomment you change it back. guess is great because it immediately tells you what it is. theta0 requires user to figure out what is happening in the rest of the code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a rule of thumb, please don't make unnecessary changes to variable names.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- This is a new function, so nothing is an unnecessary change because nothing is a change; it is new.1
theta0is the initial guess in the root finding fortheta. I will change it toguessas you requested. Whatever the name is, the public documentation next to the parametertheta0: jnp.ndarray : Optional initial guess for the computational coordinates.is there. So no one has to read code to figure out what it should be.
Footnotes
-
The function
_map_clebsch_coordinatesis private and not used anywhere in DESC. If there is external code that was using this function, we are not responsible for breaking API convention. ↩
| out : ndarray | ||
| Shape (k, 3). | ||
| DESC computational coordinates [ρ, θ, ζ]. | ||
| info : tuple |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you remove info? This is not a good coding practice. This is basically passive encryption.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens if root finding fails for some reason and I want to debug it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the residual is calculated elsewhere, ignore this comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a new function, see #1826 (comment), so I disagree with the statement about coding practice. Whatever code users have that uses root finding, nothing has changed. They can still get their info tuple.
In this new function, I have not added functionality to return auxiliary information about the root finding because that is impossible --- functions that are decorated with jnp.vectorize, such as this one, must return arrays. info is not an array.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not convinced that this root finding can fail by the way.
| @@ -79,8 +79,8 @@ def bounce_points(pitch_inv, knots, B, dB_dz, num_well=None): | |||
|
|
|||
| """ | |||
| intersect = polyroot_vec( | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This routine find the intersection of lines of constant pitch with the magnetic field B, i.e., the bounce points.
B is an array with a shape rho, alpha, 1/lambda (inv_pitch), something, num_wells.
Is that right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I put the shape of all my stuff in the public documentation string (docstring for short) because I find it helpful. Here is the definition of B and its shape in that docstring.
B : jnp.ndarray
Shape (..., N - 1, B.shape[-1]).
Polynomial coefficients of the spline of B in local power basis.
Last axis enumerates the coefficients of power series. Second to
last axis enumerates the polynomials that compose a particular spline.When you see Shape (..., N - 1, B.shape[-1]), in Python, that means the code is making the following contract with you, the developer:
The last two axes need to have that shape. All the leading axes ... don't matter, the code is agonstic to them. If there are leading axes, the code will simply perform whatever it does to the last two axes in a vectorized manner. That is the same contract that numpy makes with you. Numpy calls this contract its "broadcasting conventions".
Less experienced developers will mess up this broadcasting contract, and instead may promise broadcasting on the trailing axes of the form Shape (N - 1, B.shape[-1], ...). You should NEVER do that. If you broadcast on the trailing axes like that you break from numpy convention, and both your code and the user-facing code will have to have a million transposes and reshapes to get their stuff working with what you wrote.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So what is the shape of B for this problem?
(..., B.shape[-3], B.shape[-2], B.shape[-1]) is an answer but I meant in terms of coordinates. Docstring does not have the shape here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
N is documented here as the number of knots. B.shape[-1] is documented here as the number of coefficients in each polynomial of the spline. If you have a cubic spline that is 4. The last axis of pitch is documented here to be the number of pitch angles. The returned shape is documented here to have its last two axes be number of pitch angles, number of wells, respectively.
| ----- | ||
| Magnetic field line with label α, defined by B = ∇ψ × ∇α, is determined from | ||
| α : ρ, θ, ζ ↦ θ + λ(ρ,θ,ζ) − ι(ρ) [ζ + ω(ρ,θ,ζ)] | ||
| α : ρ, θ, ζ ↦ θ + Λ(ρ,θ,ζ) − ι(ρ) [ζ + ω(ρ,θ,ζ)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am confused again by this equation. Specifically the + sign before omega.
Is the convention zeta = phi + omega or zeta = phi - omega, where phi is the cylindrical angle and zeta is the generalized angle.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The latter. One can control +f search desc/compute/_core.py for name="phi" and name="theta_PEST" to see the definitions.
|
I addressed the comments. I think at developer meeting you all wanted me to add more of my pull request comments to code. Please make a new pull request and add whichever of my comments into the code that you would like to add. |
Progress toward addressing #1154. The [attached benchmark](https://github.com/user-attachments/files/21745848/benchmark_partial_sum.zip) must be ran on `ku/nufft` branch. 250x speed improvement from ~15 seconds to ~50 milliseconds. The improvement would be more if the FourierZernike basis is padded as discussed in #1243. Then we would avoid the 3D spectral to real transform as well as N^2 FFTs of size N. (Then this computation would likely be in microsecond range).
Progress toward addressing #1154.
The attached benchmark must be ran on
ku/nufftbranch.250x speed improvement from ~15 seconds to ~50 milliseconds.
The improvement would be more if the FourierZernike basis is padded as discussed in #1243.
Then we would avoid the 3D spectral to real transform as well as N^2 FFTs of size N.
(Then this computation would likely be in microsecond range).