Commit 368657e

📝
1 parent bf76ff2 commit 368657e

1 file changed: +5 -7 lines changed

paper.md

Lines changed: 5 additions & 7 deletions
@@ -25,7 +25,7 @@ bibliography: paper.bib
 # Summary
 
 The tensor framework JAX [@jax] combines expressivity and performance while retaining an accessible pure Python interface.
-Expressivity is achieved by treating functions as first-class objects, while efficiency is obtained by compiling to machine code Just-Ahead-Of-Time.
+Expressivity is achieved by treating functions as first-class objects, while efficiency is obtained by compiling to machine code just-ahead-of-time.
 
 However, machine learning and (high-performance) scientific computing are often conducted on different hardware stacks: Machine learning is typically done on few highly parallel units (GPUs or TPUs) connected to a single host CPU, while scientific models tend to run on clusters of dozens to thousands of CPUs.
 Unfortunately, support from JAX and the underlying compiler XLA is more mature in the former case.
@@ -59,8 +59,7 @@ Two real-world use cases for `mpi4jax` are the ocean model Veros [@hafner_veros_
 
 The implementation of a primitive in `mpi4jax` consists of two parts:
 
-1. A Python module, registering a new primitive with JAX. JAX primitives consist of an _abstract evaluation_ rule and several _translation rules_. The former is used by the compiler to infer the output shapes and data types without running the actual computation, while _translation rules_ determine the specific computational kernel and prepare the input buffers. A different _translation rule_ is necessary for every type of backend, such as CPUs, GPUs and TPUs.
-On specific primitives we also define the _transposition rule_ in order to support reverse mode Automatic Differentiation.
+1. A Python module, registering a new primitive with JAX. JAX primitives consist of an _abstract evaluation_ rule and several _translation rules_. Abstract evaluation rules are used by the compiler to infer the output shapes and data types without running the actual computation, while _translation rules_ determine the specific computational kernel and prepare the input buffers.
 
 In particular, we need to ensure that all numerical input data is of the expected type (e.g., by converting Python integers to the C type `uintptr_t`) before passing it on to XLA. A different translation rule is necessary for every type of backend, such as CPUs, GPUs and TPUs.
 

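The abstract evaluation step described in the hunk above can be observed from user code. The following is a minimal sketch, not taken from `mpi4jax`: it uses `jax.eval_shape`, which relies on the abstract evaluation rules of JAX's built-in primitives to infer output shapes and data types without running the computation. The function `column_sum` is a hypothetical stand-in for a communication primitive and is not part of `mpi4jax`.

```python
import jax
import jax.numpy as jnp


def column_sum(x):
    # Hypothetical stand-in for a reduction-style primitive (assumption,
    # not mpi4jax code): sums over the leading axis.
    return jnp.sum(x, axis=0)


# Describe the input by shape and dtype only; no array data is allocated.
spec = jax.ShapeDtypeStruct((8, 3), jnp.float32)

# eval_shape traces the function with abstract values, so only the
# abstract evaluation rules run, not the numerical kernels.
out = jax.eval_shape(column_sum, spec)

print(out.shape, out.dtype)
```

A custom primitive has to provide the same information through its own abstract evaluation rule, which is what allows the compiler to plan buffers before any communication takes place.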
@@ -197,11 +196,10 @@ As we can see, switching from NumPy to JAX already yields a substantial speedup,
 
 In this paper, we introduced `mpi4jax`, which allows zero-copy communication of JAX-owned data. `mpi4jax` provides an implementation of the most important MPI operations in a way that is usable from JAX compiled code.
 
-However, JAX is much more than just a JIT compiler. It is also a full-fledged differentiable programming framework by providing powerful tools for automatic differentiation (e.g. via `jax.grad`, `jax.vjp`, and `jax.jvp`). Differentiable programming is a promising new paradigm to combine advances in machine learning and physical modelling [@diffprog1; @diffprog2], and being able to freely distribute those models among different nodes will allow for even more powerful applications.
+However, JAX is much more than just a JIT compiler. It is also a full-fledged differentiable programming framework by providing tools for automatic differentiation (e.g. via `jax.grad`, `jax.vjp`, and `jax.jvp`). Differentiable programming is a promising new paradigm to combine advances in machine learning and physical modelling [@diffprog1; @diffprog2], and being able to freely distribute those models among different nodes will allow for even more powerful applications.
 
-So far, `mpi4jax` only supports differentiating through global sums via the `allreduce` primitive (one of the main operations used in distributed matrix-vector products) and combined send and receive (`sendrecv`) operations in forward and reverse mode.
-
-This would eventually enable fully differentiable, distributed physical simulations without additional user code.
+So far, `mpi4jax` supports differentiating through global sums via the `allreduce` primitive (one of the main operations used in distributed matrix-vector products) and combined send and receive (`sendrecv`) operations in forward and reverse mode.
+However, it should be possible with some additional work to preserve gradient information through most MPI operations, by propagating gradients through several processes. This would eventually enable fully differentiable, distributed physical simulations without additional user code.
 
 # Acknowledgements
 

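The forward- and reverse-mode tools named in the last hunk (`jax.grad`, `jax.jvp`, `jax.vjp`) can be illustrated on a single process. This is a sketch under a stated assumption: `local_sum` is a hypothetical single-process stand-in for a global sum and does not call `mpi4jax` or MPI, so it only shows the differentiation interface, not distributed gradient propagation.

```python
import jax
import jax.numpy as jnp


def local_sum(x):
    # Hypothetical single-process stand-in for a global sum (assumption):
    # a real distributed version would reduce across MPI ranks instead.
    return jnp.sum(x)


def loss(x):
    # Scalar objective built on top of the reduction.
    return local_sum(x) ** 2


x = jnp.array([1.0, 2.0, 3.0])

# Reverse mode: gradient of the scalar loss with respect to x.
grad = jax.grad(loss)(x)

# Forward mode: push a tangent vector of ones through the function.
primal, tangent = jax.jvp(loss, (x,), (jnp.ones_like(x),))

# Reverse mode, explicit form: pull a unit cotangent back to the input.
value, pullback = jax.vjp(loss, x)
(cotangent,) = pullback(jnp.ones_like(value))

print(grad, tangent, cotangent)
```

Here the pullback of the sum broadcasts the output cotangent to every input element, which is the role a transposition rule plays for a reduction primitive in reverse mode.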