🐛 Bug
Updating the libtpu nightly to any version after 04/25 causes the TPU tests to hang.
To Reproduce
- Set the libtpu nightly version to anything after 04/25 in the `setup.py` file
- Run `python test/test_operations.py`
- The test run hangs
Expected behavior
Tests should pass.
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
- torch_xla version: 04/25 nightly
Additional context
Libtpu changed the flag `--xla_tpu_use_enhanced_launch_barrier` to default to true. This flag verifies that every device the PjRt executable was compiled for is executing the same code by performing an all-reduce on the `run_id`.
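As a conceptual sketch (not the actual libtpu implementation), the barrier check can be modeled as: each participating device contributes its `run_id` to an all-reduce, and if the reduced values disagree, the devices are not running the same program. The function name and structure below are hypothetical:

```python
def launch_barrier_check(run_ids):
    """Toy model of an enhanced launch barrier.

    run_ids: the run_id each participating device would contribute to the
    all-reduce. If all devices are executing the same program, every run_id
    matches; any mismatch means the devices diverged.
    """
    lo, hi = min(run_ids), max(run_ids)
    if lo != hi:
        raise RuntimeError(
            f"launch barrier mismatch: run_ids range from {lo} to {hi}")
    return lo

# All devices agree on the run_id -> barrier passes.
launch_barrier_check([7, 7, 7, 7])
```

Note this models only the agreement check; in the real flag the all-reduce itself is also a synchronization point, so a device that never participates stalls the others.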
I think that when running `Compile` we use all the available PjRt devices for compilation. When executing the computation, the barrier then probably expects every device in that device assignment to be running the same computation, which would explain the hang when only a subset actually executes. Creating this issue to verify and fix this.
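The suspected failure mode above can be illustrated with a toy model (an assumption about the mechanism, not libtpu code): a barrier is sized for every device in the compiled executable's device assignment, but only a subset of devices actually launches, so the barrier never completes:

```python
import threading

def run_with_barrier(num_compiled_devices, num_launched_devices, timeout=0.5):
    """Toy model of the suspected hang.

    The barrier is created for all devices in the (hypothetical) device
    assignment produced at compile time, but only num_launched_devices
    actually reach it at execution time.
    """
    barrier = threading.Barrier(num_compiled_devices)
    results = []

    def device_thread(i):
        try:
            barrier.wait(timeout=timeout)  # all-reduce-style sync point
            results.append(("ok", i))
        except threading.BrokenBarrierError:
            # Models the real hang; a timeout stands in for waiting forever.
            results.append(("hang", i))

    threads = [threading.Thread(target=device_thread, args=(i,))
               for i in range(num_launched_devices)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

With `num_launched_devices == num_compiled_devices` every thread passes the barrier; with fewer launchers than the barrier was compiled for, all of them time out, which is the hang observed in the tests (here surfaced as `BrokenBarrierError` instead of an indefinite wait).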