-
Notifications
You must be signed in to change notification settings - Fork 41
Make map_coordinates differentiable for JAX 0.4.34
#1293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c3f399d
641207d
7e2c884
ec9a9f8
b71afc6
979a5fb
449eb70
2a16943
21e5f8d
8137caf
1b48da4
52957ad
531c5c4
c4e75f7
743b0fe
fed17e5
e225308
9631353
891c914
ff3a0e9
b4218ac
49c2ac8
e55d84e
4c87756
7256bce
937a547
7660d57
92bf3d1
e10b2e4
c40eaf1
b7675e6
6364dba
b376c67
b6061ce
895589a
3d94830
7a276e5
31d0f8a
0542f0e
3c83c04
732a951
ee24838
844bbc0
d613453
8dd5c3f
c6a0055
38ddf4f
a910067
d07c047
3b59ff1
0c9e1b3
22d3899
382b1cf
a7a9114
9441cee
75c1614
86456a5
3825acc
1968e23
c508da8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
YigitElma marked this conversation as resolved.
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1573,7 +1573,7 @@ def zernike_radial(r, l, m, dr=0): | |
| "Analytic radial derivatives of Zernike polynomials for order>4 " | ||
| + "have not been implemented." | ||
| ) | ||
| return s * jnp.where((l - m) % 2 == 0, out, 0) | ||
| return s * jnp.where((l - m) % 2 == 0, out, 0.0) | ||
|
|
||
|
|
||
| def power_coeffs(l): | ||
|
|
@@ -1732,7 +1732,7 @@ def _binom_body_fun(i, b_n): | |
| return b | ||
|
|
||
|
|
||
| @custom_jvp | ||
| @functools.partial(custom_jvp, nondiff_argnums=(4,)) | ||
| @jit | ||
| @jnp.vectorize | ||
| def _jacobi(n, alpha, beta, x, dx=0): | ||
|
|
@@ -1804,13 +1804,13 @@ def _jacobi_body_fun(kk, d_p_a_b_x): | |
|
|
||
|
|
||
| @_jacobi.defjvp | ||
| def _jacobi_jvp(x, xdot): | ||
| (n, alpha, beta, x, dx) = x | ||
| (ndot, alphadot, betadot, xdot, dxdot) = xdot | ||
| def _jacobi_jvp(dx, x, xdot): | ||
| (n, alpha, beta, x) = x | ||
| (*_, xdot) = xdot | ||
| f = _jacobi(n, alpha, beta, x, dx) | ||
| df = _jacobi(n, alpha, beta, x, dx + 1) | ||
| # in theory n, alpha, beta, dx aren't differentiable (they're integers) | ||
| # but marking them as non-diff argnums seems to cause escaped tracer values. | ||
| # probably a more elegant fix, but just setting those derivatives to zero seems | ||
| # to work fine. | ||
| return f, df * xdot + 0 * ndot + 0 * alphadot + 0 * betadot + 0 * dxdot | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I deleted redundant 0 multiplications because in some cases, this gives an error saying
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought that was an issue before? but I guess if you ran it with the jax versions and this passes then this is fine
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah no, the problem there was different. Back then we thought the problem was for |
||
| return f, df * xdot | ||
Uh oh!
There was an error while loading. Please reload this page.