Centuries ago, @inailuig merged MPI support into XLA in openxla/xla#7849; however, it has largely remained a secret since then, and only the most astute hackers would dare use it.
Recently, he benchmarked some of our codes with mpi4jax against a native JAX sharding implementation that uses MPI as the backend transport layer, and found identical performance.
As the cost of maintaining two implementations of many things in NetKet has been slowing down its development for the last two years, and mpi4jax is very complex to set up on GPUs these days, we decided to drop mpi4jax support there.
As JAX's native CPU parallelisation is quite poor, we still want MPI, but without mpi4jax, and in a way that is easy to use.
This is an attempt to get there, experimental for now.
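For context, a minimal sketch of what "native JAX sharding over MPI" can look like on CPU. This assumes a recent JAX where the `jax_cpu_collectives_implementation` config option and mpi4py-based cluster detection are available; exact option names may differ between versions.

```python
# Minimal sketch: native JAX sharding with MPI as the transport layer.
# Assumes a recent JAX built with XLA's MPI collectives (openxla/xla#7849)
# and mpi4py installed; option names may vary across JAX versions.
import jax

# Route CPU collectives over MPI instead of the default implementation.
jax.config.update("jax_cpu_collectives_implementation", "mpi")

# Each MPI rank joins the distributed runtime; rank/size are
# auto-detected from the MPI environment via mpi4py.
jax.distributed.initialize(cluster_detection_method="mpi4py")

import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One CPU device per process: build a 1D mesh over all ranks.
mesh = Mesh(jax.devices(), axis_names=("ranks",))
sharding = NamedSharding(mesh, PartitionSpec("ranks"))

# A global array sharded across ranks; reductions go over MPI.
x = jax.device_put(np.arange(4.0 * jax.device_count()), sharding)
print(jax.process_index(), x.sum())  # same global sum on every rank
```

Launched as usual with e.g. `mpirun -np 4 python script.py`; no mpi4jax is involved, and MPI only acts as XLA's transport for the collectives.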
cc @dionhaefner