# Summary
The tensor framework JAX [@jax] combines expressivity and performance while retaining an accessible pure Python interface.
Expressivity is achieved by treating functions as first-class objects, while efficiency is obtained by compiling to machine code just-ahead-of-time.
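As an illustration, a pure Python function can be compiled with `jax.jit`; the first call triggers compilation through XLA, and subsequent calls reuse the compiled kernel:

```python
import jax
import jax.numpy as jnp

@jax.jit
def l2_norm(x):
    # Traced once, then compiled to machine code by XLA.
    return jnp.sqrt(jnp.sum(x * x))

print(l2_norm(jnp.arange(10.0)))  # first call compiles; later calls reuse the kernel
```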
However, machine learning and (high-performance) scientific computing are often conducted on different hardware stacks: machine learning is typically done on a few highly parallel units (GPUs or TPUs) connected to a single host CPU, while scientific models tend to run on clusters of dozens to thousands of CPUs.
Unfortunately, support from JAX and the underlying compiler XLA is more mature in the former case.
The implementation of a primitive in `mpi4jax` consists of two parts:
1. A Python module that registers a new primitive with JAX. JAX primitives consist of an _abstract evaluation_ rule and several _translation rules_. The abstract evaluation rule is used by the compiler to infer output shapes and data types without running the actual computation, while _translation rules_ determine the specific computational kernel and prepare the input buffers.
In particular, we need to ensure that all numerical input data is of the expected type (e.g., by converting Python integers to the C type `uintptr_t`) before passing it on to XLA. A different translation rule is necessary for every type of backend, such as CPUs, GPUs and TPUs.
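A minimal sketch of this registration pattern is shown below. The names are illustrative rather than mpi4jax's actual internals, and JAX's primitive API has evolved across versions; backend-specific translation rules are omitted:

```python
import jax.core as core

# Illustrative sketch only -- not mpi4jax's actual source.
# A primitive is registered together with an abstract evaluation rule;
# per-backend translation rules (omitted here) would then lower it to an
# XLA custom call for CPUs, GPUs, or TPUs.
example_allreduce_p = core.Primitive("example_allreduce")

def example_abstract_eval(x):
    # Infer output shape and dtype without running any computation:
    # an allreduce returns an array of the same shape and dtype as its input.
    return core.ShapedArray(x.shape, x.dtype)

example_allreduce_p.def_abstract_eval(example_abstract_eval)

def example_allreduce(x):
    # User-facing wrapper: binds the primitive into the traced program.
    return example_allreduce_p.bind(x)
```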
In this paper, we introduced `mpi4jax`, which allows zero-copy communication of JAX-owned data. `mpi4jax` provides an implementation of the most important MPI operations in a form that is usable from JIT-compiled JAX code.
However, JAX is much more than just a JIT compiler. It is also a full-fledged differentiable programming framework, providing tools for automatic differentiation (e.g., `jax.grad`, `jax.vjp`, and `jax.jvp`). Differentiable programming is a promising new paradigm for combining advances in machine learning and physical modelling [@diffprog1; @diffprog2], and being able to freely distribute those models across different nodes will allow for even more powerful applications.
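For instance, the gradient of a scalar function is obtained with a single transformation:

```python
import jax
import jax.numpy as jnp

def energy(x):
    return jnp.sum(x ** 2)

grad_energy = jax.grad(energy)  # reverse-mode automatic differentiation
print(grad_energy(jnp.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]
```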
So far, `mpi4jax` supports differentiating through global sums via the `allreduce` primitive (one of the main operations used in distributed matrix-vector products) and combined send and receive (`sendrecv`) operations in forward and reverse mode.
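As a sketch of what this enables (assuming mpi4jax's token-returning `allreduce` API and a program launched with, e.g., `mpirun -n 4`), a globally reduced loss can be differentiated like any other JAX function:

```python
import jax
import jax.numpy as jnp
from mpi4py import MPI
import mpi4jax

def distributed_loss(x):
    local = jnp.sum(x ** 2)  # partial result on this MPI rank
    # Global sum across all ranks; gradients flow through the allreduce.
    global_sum, _token = mpi4jax.allreduce(local, op=MPI.SUM)
    return global_sum

grad_fn = jax.jit(jax.grad(distributed_loss))
print(grad_fn(jnp.ones(3)))  # each rank obtains the gradient of the global loss
```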
However, with some additional work it should be possible to preserve gradient information through most MPI operations by propagating gradients across processes. This would eventually enable fully differentiable, distributed physical simulations without additional user code.