
Commit d09daf4

dionhaefner authored and PhilipVinc committed
📝
1 parent fbc4a26 commit d09daf4

File tree: 1 file changed (+20, -16 lines)

paper.md (20 additions & 16 deletions)
@@ -24,41 +24,46 @@ bibliography: paper.bib

# Summary

-The tensor framework JAX [@jax] shows excellent performance on both machine learning and scientific computing workloads, while all user code is written in pure Python.
+The tensor framework JAX [@jax] combines expressivity and performance while providing an accessible pure Python interface.
+In particular, JAX is expressive due to its clean, functional design, and performant due to a powerful JIT (just-in-time) compiler.

-However, machine learning and high-performance computing are still being conducted on very different hardware stacks. While machine learning is typically done on few highly parallel units (GPUs or TPUs), high-performance workloads such as physical models tend to run on clusters of dozens to thousands of CPUs. Unfortunately, support from JAX and the underlying compiler XLA is much more mature in the former case. Notably, there is no built-in solution to communicate data between different nodes that is as sophisticated as the widely used MPI (Message Passing Interface) libraries [@mpistandard].
+However, machine learning and (high-performance) scientific computing are often conducted on different hardware stacks: Machine learning is typically done on a few highly parallel units (GPUs or TPUs) connected to a single host CPU, while scientific models tend to run on clusters of dozens to thousands of CPUs.
+Unfortunately, support from JAX and the underlying compiler XLA is much more mature in the former case.
+Notably, there is so far no built-in solution to communicate data between different nodes that is as sophisticated as the widely used MPI (Message Passing Interface) libraries [@mpistandard].

-In this letter we attempt to fill this gap by introducing `mpi4jax`, a Python library bringing first-class support of several MPI operations to Jax.
-This is achieved by defining a set of new primitive functions matching MPI's operations, instructing Jax how to transform them and providing a lean native code to execute them.
-This means that users can communicate arbitrary JAX data without performance or usability penalties.
+We attempt to fill this gap and introduce `mpi4jax`, a Python library bringing first-class support for the most important MPI operations to JAX.
+We achieve this by defining a set of primitive functions matching MPI's operations, instructing JAX how to transform them, and providing a native implementation to execute them.
+As a result, users can communicate arbitrary JAX data without performance or usability penalties.
In particular, `mpi4jax` is able to communicate data residing in CPU or GPU memory without copying (the latter if built against a CUDA-aware MPI library), between one or multiple hosts (e.g. via an InfiniBand network on a cluster).
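To illustrate the intended usage, here is a minimal sketch of a global reduction executed inside a jitted function. It assumes the token-based `mpi4jax.allreduce` interface together with an `mpi4py` communicator; exact argument names may differ between versions.

```python
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD

@jax.jit
def global_mean(x):
    # allreduce is assumed to return (result, token); the token threads a
    # data dependency through the communication call inside the XLA graph.
    total, _token = mpi4jax.allreduce(jnp.sum(x), op=MPI.SUM, comm=comm)
    return total / (x.size * comm.Get_size())

local = jnp.full(1000, comm.Get_rank(), dtype=jnp.float32)
print(global_mean(local))
```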

This also means that existing applications using e.g. NumPy and `mpi4py` can be ported seamlessly to the JAX ecosystem for potentially significant performance gains.

# Statement of Need

-For decades, high-performance computing has been done in low-level programming languages like Fortran or C.
-But the ubiquity of Python is starting to spill into this domain as well, thanks to its simplicity and large number of libraries.
-With a combination of NumPy [@numpy] and `mpi4py` [@mpi4py], Python users can build massively parallel applications without delving into low-level programming languages, which is often advantageous when human time is more valuable than computer time. But it is of course unsatisfying (and costly) to leave possible performance on the table.
+For decades, high-performance computing has been done primarily in low-level programming languages like Fortran or C.
+But the ubiquity of Python is starting to spill into this domain as well, thanks to its strong library ecosystem and wide adoption throughout the sciences.
+
+With a combination of NumPy [@numpy] and `mpi4py` [@mpi4py], Python users can already build massively parallel applications without delving into low-level programming languages, which is often advantageous when human time is more valuable than computer time. But it is of course unsatisfying (and costly) to leave possible performance on the table.

Google's JAX library leverages the XLA compiler and supports just-in-time compilation (JIT) of Python code to XLA primitives. [The result is highly competitive performance on both CPU and GPU](https://github.com/dionhaefner/pyhpc-benchmarks) [@pyhpc-benchmarks]. This gets us close to the dream scenario of high-performance computing --- low-level performance in high-level code. With a strong performance baseline on single devices, the only thing missing is easy scalability to massively parallel hardware stacks, which we supply here.
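As a generic illustration (not taken from the benchmark suite), a NumPy-style array function can be compiled as a whole with `jax.jit`:

```python
import jax
import jax.numpy as jnp

@jax.jit
def laplacian_1d(u, dx):
    # Plain array expressions; jax.jit traces this Python code once and
    # compiles it to a fused XLA kernel.
    return (u[:-2] - 2.0 * u[1:-1] + u[2:]) / dx**2

u = jnp.sin(jnp.linspace(0.0, 3.14, 1024))
print(laplacian_1d(u, 3.14 / 1023).shape)
```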

Two real-world use cases for `mpi4jax` are the ocean model Veros [@hafner_veros_2018] and the machine learning toolkit for many-body quantum systems NetKet [@carleo_netket_2019]:

- In the case of Veros, MPI primitives are needed to communicate overlapping grid cells between processes. Communication primitives are buried deep inside the physical subroutines. Therefore, refactoring the codebase to leave `jax.jit` every time data needs to be communicated would severely break the control flow of the model and incur a hefty performance loss (in addition to the cost of copying data from and to JAX). Through `mpi4jax`, it is possible to apply the JIT compiler to whole subroutines to avoid this entirely.

-- In the case of NetKet, the desire to achieve the highest efficiency for Natural Gradient Optimisation requires finding the solution of a large linear system $A\bm{x}=\bm{y}$. However, the matrix $A$ is determined by running Automatic Differentiation on a Neural Network Model whose inputs might be distributed across several computing nodes and GPUs. Therefore, to write in a simple yet efficient way the action of $A$ the need to differentiate through distributed reduction operations inside of a linear solver arises.
-
+- In the case of NetKet, a high-efficiency algorithm for natural gradient optimization requires solving a large linear system $A\bm{x}=\bm{y}$. The matrix $A$ is determined by running automatic differentiation on a neural network model whose inputs might be distributed across several computing nodes and GPUs. This creates the need to differentiate through distributed reduction operations inside a linear solver (see the sketch below).
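The sketch below shows what differentiating through a distributed reduction can look like. It uses a toy quadratic form in place of NetKet's actual operator and assumes the `mpi4jax.allreduce` interface from above; it is an illustration, not NetKet code.

```python
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD

def distributed_quadratic_form(theta, local_rows):
    # Each rank holds a slice of the rows; the full contraction needs a
    # global sum, expressed as an allreduce inside the differentiated code.
    local = jnp.sum((local_rows @ theta) ** 2)
    total, _token = mpi4jax.allreduce(local, op=MPI.SUM, comm=comm)
    return total

# Reverse-mode gradient of a function containing an MPI reduction,
# jit-compiled as a whole.
grad_fn = jax.jit(jax.grad(distributed_quadratic_form))

key = jax.random.PRNGKey(comm.Get_rank())
local_rows = jax.random.normal(key, (50, 10))
print(grad_fn(jnp.ones(10), local_rows))
```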

# Implementation

-`Mpi4jax` combines JAX's custom call mechanism with `mpi4py.libmpi` (which exposes MPI C primitives as Cython callables).
+`mpi4jax` combines JAX's custom call mechanism with `mpi4py.libmpi` (which exposes MPI C primitives as Cython callables).

The implementation of a primitive in `mpi4jax` consists of two parts:

-1. A Python module that registers a new primitive with JAX. JAX primitives consist of several parts, such as an abstract evaluation rule (used to infer output shapes and data types), and 2 translation rules (one for each CPU and GPU) that convert inputs to the appropriate XLA-compatible types. Optionally, we can also define transpose and differentiation rules (if applicable, see Outlook section).
+1. A Python module that registers a new primitive with JAX. JAX primitives consist of several parts, such as an *abstract evaluation* rule and several *translation rules*. The abstract evaluation rule is used by the compiler to infer output shapes and data types without running the actual computation, while translation rules supply the specific computational kernel to be called and prepare input buffers.
+
+In particular, we need to ensure that all numerical input data is of the expected type (e.g., by converting Python integers to the C type `uintptr_t`) before passing it on to XLA. A different translation rule is necessary for every type of backend, such as CPUs, GPUs, and TPUs.

-In particular, we need to ensure that all numerical input data is of the expected type (e.g., by converting Python integers to the C type `uintptr_t`) before passing it on to XLA.
+For specific primitives we also define transposition and JVP (Jacobian-vector product) rules to support forward- and reverse-mode automatic differentiation (see the sketch after this list).

2. A Cython [@cython] function that casts raw input arguments passed by XLA to their true C type, so they can be passed on to MPI. On CPU, arguments are given in the form of arrays of void pointers, `void**`, so we use static casts for conversion. On GPU, input data is given as a raw char array, `char*`, which we deserialize to a custom Cython `struct` whose fields represent the input data.
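To make the first part more concrete, the following self-contained sketch registers a toy allreduce-like primitive using JAX's primitive machinery. It shows only the abstract evaluation, eager implementation, and JVP rules; `mpi4jax`'s actual CPU/GPU translation rules are registered as XLA custom calls backed by Cython and are omitted here. The name `dummy_allreduce` and the eager fallback via `mpi4py` are illustrative choices, not `mpi4jax` internals.

```python
import numpy as np
from mpi4py import MPI
from jax import core
from jax.interpreters import ad

# A toy primitive standing in for an MPI allreduce with op=SUM.
dummy_allreduce_p = core.Primitive("dummy_allreduce")

def dummy_allreduce(x):
    return dummy_allreduce_p.bind(x)

# Abstract evaluation: output shape and dtype equal the input's, so the
# compiler can reason about the op without executing it.
def _abstract_eval(x):
    return core.ShapedArray(x.shape, x.dtype)

dummy_allreduce_p.def_abstract_eval(_abstract_eval)

# Eager (non-jitted) implementation: defer to mpi4py directly. mpi4jax
# instead supplies per-backend translation rules via XLA custom calls.
def _impl(x):
    return MPI.COMM_WORLD.allreduce(np.asarray(x), op=MPI.SUM)

dummy_allreduce_p.def_impl(_impl)

# Differentiation: a sum-allreduce is linear, so the JVP applies the same
# reduction to the tangents (handling of symbolic zeros omitted).
def _jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    return dummy_allreduce(x), dummy_allreduce(dx)

ad.primitive_jvps[dummy_allreduce_p] = _jvp
```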

@@ -191,10 +196,9 @@ As we can see, switching from NumPy to JAX already yields a substantial speedup,

In this paper, we introduced `mpi4jax`, which allows zero-copy communication of JAX-owned data. `mpi4jax` provides an implementation of the most important MPI operations in a way that is usable from JAX compiled code.

-However, JAX is much more than just a JIT compiler: by providing powerful tools for automatic differentiation (`jax.vjp` and `jax.jvp`) and auto-vectorization (`jax.vmap`), it is a full-fledged Differentiable Programming framework.
-Differentiable programming in particular is a promising new paradigm to combine advances in machine learning and physical modelling [@diffprog1; @diffprog2], and being able to freely distribute those models among different nodes will allow even more powerful algorithms.
+However, JAX is much more than just a JIT compiler. By providing powerful tools for automatic differentiation (e.g. `jax.grad`, `jax.vjp`, and `jax.jvp`) and auto-vectorization (`jax.vmap`), it is also a full-fledged differentiable programming framework. Differentiable programming in particular is a promising new paradigm to combine advances in machine learning and physical modelling [@diffprog1; @diffprog2], and being able to freely distribute those models among different nodes will allow even more powerful applications.
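These transformations compose freely; a generic illustration (independent of `mpi4jax`):

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Scalar loss for a single sample x given parameters w.
    return jnp.sum((x @ w) ** 2)

# Reverse-mode gradient w.r.t. w, vectorized over a batch of samples,
# and jit-compiled as a whole.
batched_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

w = jnp.ones(3)
xs = jnp.arange(12.0).reshape(4, 3)
print(batched_grad(w, xs).shape)  # (4, 3): one gradient per sample
```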

-So far, `mpi4jax` only supports differentiating through global sums via the `allreduce` primitive, the main operation occurring in distributed linear algebra.
+So far, `mpi4jax` only supports differentiating through global sums via the `allreduce` primitive (one of the main operations occurring in distributed linear algebra).
However, it should be possible with some additional work to compute the gradients of generic send/receive operations by propagating gradients through several processes.

This would eventually enable fully differentiable, distributed physical simulations without additional user code.
