Unless I am mistaken, the distributed backend (initialised with `jax.distributed.initialize()`) can only be used with the GPU and TPU backends.
However, I believe TensorFlow, and therefore XLA, supports it with the CPU backend as well.
Would it be possible to support this in JAX too, so that `pjit` and related APIs can be used across multiple CPU processes?