Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
c3f399d
fix the custom_jvp problem with new jax version, we don't need to mul…
YigitElma Oct 4, 2024
641207d
Merge branch 'master' into yge/customjvp_fix
YigitElma Oct 4, 2024
7e2c884
revert back to latest JAX version
YigitElma Oct 4, 2024
ec9a9f8
update jax_test versions
YigitElma Oct 4, 2024
b71afc6
have jax_tests use python 3.10 so can test recent versions
dpanici Oct 6, 2024
979a5fb
fix python version
dpanici Oct 6, 2024
449eb70
revert back to multiplying by 0s
YigitElma Oct 6, 2024
2a16943
Merge branch 'yge/customjvp_fix' of github.com:PlasmaControl/DESC int…
YigitElma Oct 6, 2024
21e5f8d
fix python version syntax
YigitElma Oct 6, 2024
8137caf
fix jax test to install jax first and then decide other dependency ve…
YigitElma Oct 6, 2024
1b48da4
update matplotlib latest version on mpl test, force 3.7.2 for none ca…
YigitElma Oct 6, 2024
52957ad
fix jax test
YigitElma Oct 6, 2024
531c5c4
fix incorrect name of dev-requirements
dpanici Oct 7, 2024
c4e75f7
update jax dependency test
YigitElma Oct 7, 2024
743b0fe
Merge branch 'yge/customjvp_fix' of github.com:PlasmaControl/DESC int…
YigitElma Oct 7, 2024
fed17e5
clean-up the dependency installation and print installed dependencies…
YigitElma Oct 7, 2024
e225308
take jax dependency to the top of the file
YigitElma Oct 7, 2024
9631353
add custom_jvp to zernike_radial directly, this won't work for jax 0.…
YigitElma Oct 7, 2024
891c914
fix missing argument dr problem
YigitElma Oct 7, 2024
ff3a0e9
re-add the _jacobi custom_jvp for the test but actually not needed fo…
YigitElma Oct 7, 2024
b4218ac
Merge branch 'master' into yge/customjvp_fix
YigitElma Oct 11, 2024
49c2ac8
use nondiff_argnums for zernike_radial custom_jvp
YigitElma Oct 15, 2024
e55d84e
Merge branch 'yge/customjvp_fix' of github.com:PlasmaControl/DESC int…
YigitElma Oct 15, 2024
4c87756
revert back to old version, nondiff creates some unexpected tracers f…
YigitElma Oct 15, 2024
7256bce
Merge branch 'master' into yge/customjvp_fix
YigitElma Oct 15, 2024
937a547
fix auxilary return values in a hacky way until JAX people reply
YigitElma Oct 19, 2024
7660d57
Merge branch 'master' into yge/customjvp_fix
YigitElma Oct 19, 2024
92bf3d1
apply the same fix to root_scalar, add docstring, apply same to numpy…
YigitElma Oct 19, 2024
e10b2e4
fix constant_offset_surface function
YigitElma Oct 19, 2024
c40eaf1
make full_output case also differentiable, increase coverage
YigitElma Oct 19, 2024
b7675e6
fix zeta phi problem causing nans
YigitElma Oct 20, 2024
6364dba
make test_map_coordinates_derivative test different cases of map_coor…
YigitElma Oct 20, 2024
b376c67
move matplotlib changes to new PR
YigitElma Oct 20, 2024
b6061ce
move matplotlib changes to new PR
YigitElma Oct 20, 2024
895589a
add root and root_scalar tests as well as their derivatives
YigitElma Oct 20, 2024
3d94830
revert float64 stuff
YigitElma Oct 21, 2024
7a276e5
bump minimum version of jax to 0.4.24
YigitElma Oct 21, 2024
31d0f8a
Merge branch 'master' into yge/customjvp_fix
YigitElma Oct 21, 2024
0542f0e
just checking jax version
YigitElma Oct 21, 2024
3c83c04
try to test scipy
YigitElma Oct 21, 2024
732a951
back to previous version
YigitElma Oct 21, 2024
ee24838
fix conda requirements
YigitElma Oct 21, 2024
844bbc0
Merge branch 'master' into yge/customjvp_fix
YigitElma Oct 22, 2024
d613453
update minimum jax requirement, remove cpu flag
YigitElma Oct 23, 2024
8dd5c3f
update minimum jax requirement, remove cpu flag
YigitElma Oct 23, 2024
c6a0055
solve to the permutation todo on map coordinates
YigitElma Oct 23, 2024
38ddf4f
update installation instructions, solve sphinx-argparse TODO, add ipy…
YigitElma Oct 23, 2024
a910067
update requirements
YigitElma Oct 23, 2024
d07c047
reverse the doc requirements
YigitElma Oct 23, 2024
3b59ff1
remove duplicated line
YigitElma Oct 23, 2024
0c9e1b3
remove the for deleting the jax[cpu]
YigitElma Oct 24, 2024
22d3899
rephrase ipykernel comment
YigitElma Oct 24, 2024
382b1cf
revert changes for TODO
YigitElma Oct 24, 2024
a7a9114
Merge branch 'master' into yge/customjvp_fix
dpanici Oct 24, 2024
9441cee
Merge branch 'master' into yge/customjvp_fix
dpanici Oct 28, 2024
75c1614
remove ipykernel stuff, if people need they can install, we don't hav…
YigitElma Oct 28, 2024
86456a5
Merge branch 'yge/customjvp_fix' of github.com:PlasmaControl/DESC int…
YigitElma Oct 28, 2024
3825acc
Merge branch 'master' into yge/customjvp_fix
YigitElma Oct 28, 2024
1968e23
update changelog, rename inbasis
YigitElma Oct 29, 2024
c508da8
Merge branch 'master' into yge/customjvp_fix
YigitElma Oct 30, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions .github/workflows/jax_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,35 @@ jobs:
strategy:
fail-fast: false
matrix:
jax-version: [0.3.0, 0.3.1, 0.3.2, 0.3.3, 0.3.4, 0.3.5,
0.3.6, 0.3.7, 0.3.8, 0.3.9, 0.3.10, 0.3.11,
0.3.12, 0.3.13, 0.3.14, 0.3.15, 0.3.16, 0.3.17,
0.3.19, 0.3.20, 0.3.21, 0.3.22, 0.3.23, 0.3.24,
0.3.25, 0.4.1, 0.4.2, 0.4.3, 0.4.4, 0.4.5,
0.4.6, 0.4.7, 0.4.8, 0.4.9, 0.4.10, 0.4.11,
0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18,
jax-version: [0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17,
0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23,
0.4.24, 0.4.25]
0.4.24, 0.4.25, 0.4.26, 0.4.27, 0.4.28, 0.4.29,
0.4.30, 0.4.31, 0.4.33, 0.4.34, 0.4.35]
# 0.4.32 is not available on PyPI
# earlier jax versions are not compatible with other
# dependencies as of 2024-10-04
group: [1, 2]
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.9
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: 3.9
cache: pip
- name: Install dependencies
python-version: '3.10'
- name: Upgrade pip
run: |
python -m pip install --upgrade pip
pip install -r devtools/dev-requirements.txt
pip install matplotlib==3.5.0
- name: Remove jax
- name: Install dependencies with given JAX version
run: |
pip uninstall jax jaxlib -y
- name: install jax
sed -i '/jax/d' ./requirements.txt
sed -i '1i\jax[cpu] == ${{ matrix.jax-version }}' ./requirements.txt
cat ./requirements.txt
pip install -r ./devtools/dev-requirements.txt
pip install matplotlib==3.7.2
- name: Verify dependencies
run: |
pip install "jax[cpu]==${{ matrix.jax-version }}"
python --version
pip --version
pip list
- name: Test with pytest
run: |
pwd
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Bug Fixes

- Fixes bugs that occur when saving asymmetric equilibria as wout files
- Fixes bug that occurs when using ``VMECIO.plot_vmec_comparison`` to compare to an asymmetric wout file
- Fixes bug that occurs when taking the gradient of ``root`` and ``root_scalar`` with newer versions of JAX (>=0.4.34) and unpins the JAX version

Deprecations

Expand Down
73 changes: 56 additions & 17 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,12 @@
treedef_is_leaf,
)

trapezoid = (
jnp.trapezoid if hasattr(jnp, "trapezoid") else jax.scipy.integrate.trapezoid
)
if hasattr(jnp, "trapezoid"):
trapezoid = jnp.trapezoid # for JAX 0.4.26 and later
elif hasattr(jax.scipy, "integrate"):
trapezoid = jax.scipy.integrate.trapezoid
else:
trapezoid = jnp.trapz # for older versions of JAX, deprecated by jax 0.4.16

def execute_on_cpu(func):
"""Decorator to set default device to CPU for a function.
Expand Down Expand Up @@ -200,6 +203,7 @@ def root_scalar(
maxiter_ls=5,
alpha=0.1,
fixup=None,
full_output=False,
):
"""Find x where fun(x, *args) == 0.

Expand Down Expand Up @@ -227,6 +231,9 @@ def root_scalar(
fixup : callable, optional
Function to modify x after each update, ie to enforce periodicity. Should
have a signature of the form fixup(x, *args) -> x'.
full_output : bool, optional
If True, also return a tuple where the first element is the residual from
the root finding and the second is the number of iterations.

Returns
-------
Expand Down Expand Up @@ -271,18 +278,25 @@ def bodyfun(state):
xk1, fk1 = backtrack(xk1, fk1, d)
return xk1, fk1, k1 + 1

state = guess, res(guess), 0
state = guess, res(guess), 0.0
state = jax.lax.while_loop(condfun, bodyfun, state)
return state[0], state[1:]
if full_output:
return state[0], state[1:]
else:
return state[0]

def tangent_solve(g, y):
A = jax.jacfwd(g)(y)
return y / A

x, (res, niter) = jax.lax.custom_root(
res, x0, solve, tangent_solve, has_aux=True
)
return x, (abs(res), niter)
if full_output:
x, (res, niter) = jax.lax.custom_root(
res, x0, solve, tangent_solve, has_aux=True
)
return x, (abs(res), niter)
else:
x = jax.lax.custom_root(res, x0, solve, tangent_solve, has_aux=False)
return x

def root(
fun,
Expand All @@ -294,6 +308,7 @@ def root(
maxiter_ls=0,
alpha=0.1,
fixup=None,
full_output=False,
):
"""Find x where fun(x, *args) == 0.

Expand Down Expand Up @@ -321,6 +336,9 @@ def root(
fixup : callable, optional
Function to modify x after each update, ie to enforce periodicity. Should
have a signature of the form fixup(x, *args) -> 1d array.
full_output : bool, optional
If True, also return a tuple where the first element is the residual from
the root finding and the second is the number of iterations.

Returns
-------
Expand Down Expand Up @@ -388,19 +406,26 @@ def bodyfun(state):
state = (
jnp.atleast_1d(jnp.asarray(guess)),
jnp.atleast_1d(resfun(guess)),
0,
0.0,
)
state = jax.lax.while_loop(condfun, bodyfun, state)
return state[0], state[1:]
if full_output:
return state[0], state[1:]
else:
return state[0]

def tangent_solve(g, y):
A = jnp.atleast_2d(jax.jacfwd(g)(y))
return _lstsq(A, jnp.atleast_1d(y))

x, (res, niter) = jax.lax.custom_root(
res, x0, solve, tangent_solve, has_aux=True
)
return x, (safenorm(res), niter)
if full_output:
x, (res, niter) = jax.lax.custom_root(
res, x0, solve, tangent_solve, has_aux=True
)
return x, (safenorm(res), niter)
else:
x = jax.lax.custom_root(res, x0, solve, tangent_solve, has_aux=False)
return x


# we can't really test the numpy backend stuff in automated testing, so we ignore it
Expand Down Expand Up @@ -711,6 +736,7 @@ def root_scalar(
maxiter_ls=5,
alpha=0.1,
fixup=None,
full_output=False,
):
"""Find x where fun(x, *args) == 0.

Expand Down Expand Up @@ -738,6 +764,9 @@ def root_scalar(
fixup : callable, optional
Function to modify x after each update, ie to enforce periodicity. Should
have a signature of the form fixup(x) -> x'.
full_output : bool, optional
If True, also return a tuple where the first element is the residual from
the root finding and the second is the number of iterations.

Returns
-------
Expand All @@ -750,7 +779,10 @@ def root_scalar(
out = scipy.optimize.root_scalar(
fun, args, x0=x0, fprime=jac, xtol=tol, rtol=tol
)
return out.root, out
if full_output:
return out.root, out
else:
return out.root

def root(
fun,
Expand All @@ -762,6 +794,7 @@ def root(
maxiter_ls=0,
alpha=0.1,
fixup=None,
full_output=False,
):
"""Find x where fun(x, *args) == 0.

Expand Down Expand Up @@ -789,6 +822,9 @@ def root(
fixup : callable, optional
Function to modify x after each update, ie to enforce periodicity. Should
have a signature of the form fixup(x, *args) -> 1d array.
full_output : bool, optional
If True, also return a tuple where the first element is the residual from
the root finding and the second is the number of iterations.

Returns
-------
Expand All @@ -803,7 +839,10 @@ def root(
will solve it in a least squares sense.
"""
out = scipy.optimize.root(fun, x0, args, jac=jac, tol=tol)
return out.x, out
if full_output:
return out.x, out
else:
return out.x

def flatnonzero(a, size=None, fill_value=0):
"""A numpy implementation of jnp.flatnonzero."""
Expand Down
12 changes: 6 additions & 6 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1573,7 +1573,7 @@ def zernike_radial(r, l, m, dr=0):
"Analytic radial derivatives of Zernike polynomials for order>4 "
+ "have not been implemented."
)
return s * jnp.where((l - m) % 2 == 0, out, 0)
return s * jnp.where((l - m) % 2 == 0, out, 0.0)


def power_coeffs(l):
Expand Down Expand Up @@ -1732,7 +1732,7 @@ def _binom_body_fun(i, b_n):
return b


@custom_jvp
@functools.partial(custom_jvp, nondiff_argnums=(4,))
@jit
@jnp.vectorize
def _jacobi(n, alpha, beta, x, dx=0):
Expand Down Expand Up @@ -1804,13 +1804,13 @@ def _jacobi_body_fun(kk, d_p_a_b_x):


@_jacobi.defjvp
def _jacobi_jvp(x, xdot):
(n, alpha, beta, x, dx) = x
(ndot, alphadot, betadot, xdot, dxdot) = xdot
def _jacobi_jvp(dx, x, xdot):
(n, alpha, beta, x) = x
(*_, xdot) = xdot
f = _jacobi(n, alpha, beta, x, dx)
df = _jacobi(n, alpha, beta, x, dx + 1)
# in theory n, alpha, beta, dx aren't differentiable (they're integers)
# but marking them as non-diff argnums seems to cause escaped tracer values.
# probably a more elegant fix, but just setting those derivatives to zero seems
# to work fine.
return f, df * xdot + 0 * ndot + 0 * alphadot + 0 * betadot + 0 * dxdot
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I deleted redundant 0 multiplications because in some cases, this gives an error saying float0 cannot be used in math operations...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought that was an issue before? but I guess if you ran it with the jax versions and this passes then this is fine

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah no, the problem there was different. Back then we thought the problem was for zernike_radial, but it was root and root_scalar

return f, df * xdot
36 changes: 27 additions & 9 deletions desc/equilibrium/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,18 @@ def fixup(y, *args):
fixup=fixup,
tol=tol,
maxiter=maxiter,
full_output=full_output,
**kwargs,
)
)
)
# See description here
# https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532
# except we make sure properly handle periodic coordinates.
yk, (res, niter) = vecroot(yk, coords)
if full_output:
yk, (res, niter) = vecroot(yk, coords)
else:
yk = vecroot(yk, coords)

out = compute(yk, outbasis)
if full_output:
Expand Down Expand Up @@ -363,18 +367,28 @@ def fixup(x, *args):
fixup=fixup,
tol=tol,
maxiter=maxiter,
full_output=full_output,
**kwargs,
)
)
)
rho, theta_PEST, zeta = coords.T
theta, (res, niter) = vecroot(
# Assume λ=0 for default initial guess.
setdefault(guess, theta_PEST),
theta_PEST,
rho,
zeta,
)
if full_output:
theta, (res, niter) = vecroot(
# Assume λ=0 for default initial guess.
setdefault(guess, theta_PEST),
theta_PEST,
rho,
zeta,
)
else:
theta = vecroot(
# Assume λ=0 for default initial guess.
setdefault(guess, theta_PEST),
theta_PEST,
rho,
zeta,
)
out = jnp.column_stack([rho, jnp.atleast_1d(theta.squeeze()), zeta])
if full_output:
return out, (res, niter)
Expand Down Expand Up @@ -466,6 +480,7 @@ def fixup(x, *args):
fixup=fixup,
tol=tol,
maxiter=maxiter,
full_output=full_output,
**kwargs,
)
)
Expand All @@ -474,7 +489,10 @@ def fixup(x, *args):
if guess is None:
# Assume λ=0 for default initial guess.
guess = alpha + iota * zeta
theta, (res, niter) = vecroot(guess, alpha, rho, zeta, iota)
if full_output:
theta, (res, niter) = vecroot(guess, alpha, rho, zeta, iota)
else:
theta = vecroot(guess, alpha, rho, zeta, iota)

out = jnp.column_stack([rho, jnp.atleast_1d(theta.squeeze()), zeta])
if full_output:
Expand Down
15 changes: 12 additions & 3 deletions desc/geometry/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,10 +741,19 @@ def fun_jax(zeta_hat, theta, zeta):
n, r, r_offset = n_and_r_jax(nodes)
return jnp.arctan(r_offset[0, 1] / r_offset[0, 0]) - zeta

vecroot = jit(vmap(lambda x0, *p: root_scalar(fun_jax, x0, jac=None, args=p)))
zetas, (res, niter) = vecroot(
grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2]
vecroot = jit(
vmap(
lambda x0, *p: root_scalar(
fun_jax, x0, jac=None, args=p, full_output=full_output
)
)
)
if full_output:
zetas, (res, niter) = vecroot(
grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2]
)
else:
zetas = vecroot(grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2])

zetas = np.asarray(zetas)
nodes = np.vstack((np.ones_like(grid.nodes[:, 1]), grid.nodes[:, 1], zetas)).T
Expand Down
3 changes: 1 addition & 2 deletions devtools/dev-requirements_conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ dependencies:
- pip:
# Conda only parses a single list of pip requirements.
# If two pip lists are given, all but the last list is skipped.
- jax >= 0.4.24, < 0.5.0
- diffrax >= 0.4.1
- interpax >= 0.3.3
- jax[cpu] >= 0.3.2, < 0.5.0
- nvgpu
- orthax
- plotly >= 5.16, < 6.0
Expand All @@ -29,7 +29,6 @@ dependencies:
- qicna @ git+https://github.com/rogeriojorge/pyQIC/
- black[jupyter] = 24.3.0


# building the docs
- nbsphinx == 0.8.12
- pandoc
Expand Down
Loading