-
I work in computational neuroscience and do a lot of large-scale computations (perhaps not as big as what the genomics folks run). These computations are usually simple linear algebra operations, parallelized across a bunch of CPUs. Meanwhile, our lab has a large number of GPUs available. What would be a good way to combine the ease of use of xarray with JAX?

In my workflow, I already have a lot of vectorized calls, and they save me a lot of time. But these are still a large number of small linear algebra operations that might see huge gains on GPUs. A naive approach is to simply convert to JAX inside a vectorized function. For example, I vectorize calls over a number of brain scans with dimensions [num_subjects, num_conditions, x, y, z, features], where x, y, z are the spatial dimensions. I have a function

But this is likely to incur a large conversion overhead. In the case where the entire dataset fits into memory un-chunked, what do you think is a better approach?
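A minimal sketch of the naive approach done with a single conversion, assuming the whole array fits in memory un-chunked. The function from the question was not preserved, so `project_norm` (a batched projection followed by a norm) and `project_norm_xr` are hypothetical stand-ins; only the dimension names come from the question. The idea is that `xr.apply_ufunc` without `vectorize=True` hands the wrapped function the entire NumPy array with the core dimension moved last, so JAX sees one batched call instead of many small ones.

```python
import numpy as np
import xarray as xr
import jax
import jax.numpy as jnp


# Hypothetical per-voxel operation standing in for the question's function:
# project each voxel's feature vector onto a fixed weight matrix and take
# the norm of the projection.
@jax.jit
def project_norm(features, weights):
    # features: (..., n_features), weights: (n_features, k)
    projected = features @ weights  # batched matmul over all leading dims
    return jnp.linalg.norm(projected, axis=-1)


def project_norm_xr(da: xr.DataArray, weights: np.ndarray) -> xr.DataArray:
    def _on_device(x):
        # Without vectorize=True, apply_ufunc passes the whole NumPy array
        # with "features" as the last axis: one host->device copy in and
        # one device->host copy out per call, not one per voxel.
        return np.asarray(project_norm(jnp.asarray(x), jnp.asarray(weights)))

    return xr.apply_ufunc(
        _on_device,
        da,
        input_core_dims=[["features"]],
        output_core_dims=[[]],  # "features" is reduced away
    )


dims = ("subject", "condition", "x", "y", "z", "features")
rng = np.random.default_rng(0)
scans = xr.DataArray(rng.standard_normal((2, 3, 4, 4, 4, 5)), dims=dims)
weights = rng.standard_normal((5, 2))
result = project_norm_xr(scans, weights)
```

Whether this beats the CPU path will depend on array sizes and on how much work is done per transfer; the point of the sketch is that the conversion cost is paid once per call rather than once per small operation. Note that with JAX's default settings the computation runs in float32.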
-
If you would be repeatedly converting to and from JAX, it would probably be best to keep everything in JAX and have xarray wrap that. According to jax-ml/jax#17107, support for doing so seems to be close but not quite there yet. In particular, we'd need JAX to implement the array API (i.e., define __array_namespace__); see jax-ml/jax#16099, an in-progress PR.
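To illustrate "keeping everything in JAX" without relying on xarray wrapping JAX arrays directly, here is a sketch that converts once at the boundaries, chains the intermediate steps inside a single `jax.jit` function, and rebuilds a `DataArray` from the result. `exposes_array_api`, `pipeline`, and `run_in_jax` are hypothetical names; `exposes_array_api` only checks for the `__array_namespace__` attribute mentioned above, and the computation (sum of squared deviations via an outer product and a trace) is an arbitrary stand-in for a chain of small linear algebra steps.

```python
import numpy as np
import xarray as xr
import jax
import jax.numpy as jnp


def exposes_array_api(x) -> bool:
    # xarray can wrap a duck array natively once it exposes the array API
    # entry point (__array_namespace__).
    return hasattr(x, "__array_namespace__")


@jax.jit
def pipeline(x):
    # Several chained steps that all stay on the device; no NumPy
    # round-trips between them.
    centered = x - x.mean(axis=-1, keepdims=True)
    outer = jnp.einsum("...i,...j->...ij", centered, centered)
    return jnp.trace(outer, axis1=-2, axis2=-1)  # sum of squared deviations


def run_in_jax(da: xr.DataArray, core_dim: str = "features") -> xr.DataArray:
    ordered = da.transpose(..., core_dim)  # core dimension last
    out = np.asarray(pipeline(jnp.asarray(ordered.values)))  # one copy each way
    kept = ordered.dims[:-1]
    coords = {d: ordered.coords[d] for d in kept if d in ordered.coords}
    return xr.DataArray(out, dims=kept, coords=coords)


rng = np.random.default_rng(1)
scans = xr.DataArray(
    rng.standard_normal((2, 3, 5)),
    dims=("subject", "condition", "features"),
    coords={"subject": ["s1", "s2"]},
)
ss = run_in_jax(scans)
```

Once xarray can wrap JAX arrays directly, the manual boundary handling in `run_in_jax` should become unnecessary; until then, keeping the device-side work inside one jitted function is what avoids the repeated conversions.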
-
Hi @keewis, the JAX PR (jax-ml/jax#16099) has been merged, so is there an updated answer to this discussion?
-
Hi! Cheers
-
Starting from JAX
Two additional issues for Xarray + JAX integration have been identified:
Issue 2 has an example solution in the GraphCast project. The main challenge I've run into there is that many JAX libraries (e.g., Diffrax) use tree manipulation patterns that produce "dummy trees", such as boolean masks. If a
I can think of three solutions to this:
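To make the "dummy tree" problem concrete, here is a hypothetical minimal labeled-array pytree; `LabeledArray` is an illustrative class, not xarray's or GraphCast's API. Its constructor validates that the data has one dimension per label, which is reasonable for real arrays but breaks when a library rebuilds the tree with boolean leaves. One general workaround, sketched here under that assumption and not necessarily one of the solutions referred to in this comment, is to have `tree_unflatten` bypass `__init__` so that any leaf type round-trips.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class LabeledArray:
    """Hypothetical minimal labeled array (not xarray's API)."""

    def __init__(self, data, dims):
        # Strict validation: fine for real arrays and tracers, but it fails
        # on a "dummy" leaf such as a bare bool, which has no .ndim.
        if data.ndim != len(dims):
            raise ValueError("len(dims) must equal data.ndim")
        self.data = data
        self.dims = tuple(dims)

    def tree_flatten(self):
        # The array is the only leaf; dims are static, hashable aux data.
        return (self.data,), self.dims

    @classmethod
    def tree_unflatten(cls, dims, children):
        # Skip __init__ so libraries can rebuild the tree with arbitrary
        # leaves (bool masks, sentinels) without tripping validation.
        obj = object.__new__(cls)
        obj.data = children[0]
        obj.dims = dims
        return obj


x = LabeledArray(jnp.ones((2, 3)), ("x", "y"))
# A "dummy tree": same structure, boolean leaves instead of arrays.
mask = tree_util.tree_map(lambda _: True, x)
# Ordinary transformations still work, with dims carried as static data.
doubled = jax.jit(lambda a: LabeledArray(a.data * 2, a.dims))(x)
```

The trade-off is that validation only runs when user code calls the constructor, so a malformed tree built by a library is not caught at unflatten time.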
-
I pushed a
-
xarray/xarray/core/variable.py, line 338 (at commit 44c7c15)