WARNING:2025-04-06 19:31:14,948:jax._src.xla_bridge:971: A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.
W0406 19:31:14.948704 125696970082304 xla_bridge.py:971] A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.
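The warning above can be fixed by uninstalling the jax package that was installed as a dependency of flax and installing jax[tpu] instead (the quotes keep shells such as zsh from treating the brackets as a glob pattern):
pip3 uninstall -y jax
pip3 install "jax[tpu]"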
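To confirm the reinstall took effect, you can ask JAX which backend it selected. The sketch below assumes it runs under the same Python environment that pip3 installed into. The helper name jax_backend is hypothetical and not part of JAX. After a working jax[tpu] install, jax.default_backend() should report "tpu" instead of falling back to "cpu".

```python
import importlib.util
from typing import Optional


def jax_backend() -> Optional[str]:
    """Return the platform JAX selected by default, or None if jax is missing.

    Hypothetical helper (not a JAX API): "tpu" means libtpu was picked up,
    "cpu" means JAX fell back as in the warning this note is about.
    """
    if importlib.util.find_spec("jax") is None:
        return None  # jax is not installed in this interpreter's environment
    import jax  # imported lazily so the check also runs without jax installed

    return jax.default_backend()


if __name__ == "__main__":
    backend = jax_backend()
    if backend is None:
        print("jax is not installed in this environment")
    elif backend == "tpu":
        import jax

        print("TPU backend active:", jax.devices())
    else:
        print(f"JAX is using {backend!r}; libtpu was not picked up")
```

Importing jax lazily keeps the check usable as a quick diagnostic even in an environment where the uninstall step succeeded but the install step did not.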