Description
Environment
Python version: 3.11.11
JAX Version: 0.5.2
GPU Details:
- Local System: RTX 4080 (Ada Lovelace architecture) with NVIDIA driver 550, AMD Ryzen 7 CPU on Ubuntu.
- Google Colab: Tesla T4 (Turing architecture) with NVIDIA driver 550, Intel Xeon CPU.
Issue
I noticed that the RL performance in the `manipulation.ipynb` example varied wildly between the sample Colab instance and my PC. The task performance degrades rapidly after the first iteration on my PC, while it generally keeps improving on Colab:
[Reward curves: Colab (Tesla T4) vs. my PC (RTX 4080)]
Also, when I ran the `train_jax_ppo.py` script for the default LeapCubeReorient task, it did not train at all; the average episode length even drops to zero after the first iteration.
I tried to match the software versions (e.g. the Python libraries `jax`, `jaxlib`, `jax-cuda12-plugin`, `brax`, `mujoco`, `flax`, and also the NVIDIA driver and CUDA versions), but that didn't fix the discrepancy.
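For anyone comparing runs across machines, a small version-report snippet like the one below (my own addition, assuming the listed packages are installed) makes it easier to rule out library mismatches:

```python
import jax
import jaxlib
import brax
import flax
import mujoco

# Print the library versions and the accelerator JAX actually picked up,
# so two environments can be compared side by side.
print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)
print("brax:", brax.__version__)
print("flax:", flax.__version__)
print("mujoco:", mujoco.__version__)
print("backend:", jax.default_backend())
print("devices:", jax.devices())
```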
Fix
After some searching, this Q&A pointed me to this answer, which suggests increasing the matmul precision in JAX to "high". This still didn't fix the performance for the LeapCubeReorient task (it did work for the first PandaPickCubeOrientation task), so I then tried:
```python
jax.config.update('jax_default_matmul_precision', 'highest')
```
which fixed the performance issue, as confirmed by the reward curves in Weights & Biases.
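For reference, here is a minimal sketch (my own addition, not from the notebook) of the different ways the precision can be raised, either globally or only around the affected code:

```python
import os

# Option 1: environment variable, set before JAX is imported/initialized.
os.environ["JAX_DEFAULT_MATMUL_PRECISION"] = "highest"

import jax

# Option 2: global config flag, equivalent to the line above.
jax.config.update("jax_default_matmul_precision", "highest")

# Option 3: scope the change with a context manager, so only the wrapped
# computations use full float32 matmuls.
with jax.default_matmul_precision("highest"):
    ...  # e.g. run the PPO training loop here
```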
Cause
This seems to result from the fact that on the RTX 4080, JAX defaults to TF32 matrix multiplications. TF32 is a reduced-precision format introduced with the Ampere architecture (and also used by newer GPUs such as Ada Lovelace); it keeps far fewer mantissa bits than full float32, as noted in issues like jax-ml/jax#12008, but is much faster thanks to dedicated tensor-core hardware. On the Tesla T4 (Turing), which has no TF32 support, full float32 precision is used by default.
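A quick way to see the effect (a minimal sketch I put together, not from the original runs) is to compare a default-precision matmul against one forced to full float32, using a float64 result computed on the host as the reference:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Two random float32 matrices.
a = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024), dtype=jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), (1024, 1024), dtype=jnp.float32)

# Reference result in float64 on the host.
ref = np.asarray(a, dtype=np.float64) @ np.asarray(b, dtype=np.float64)

default = jnp.dot(a, b)  # uses TF32 on Ampere-and-newer GPUs
highest = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)  # full float32

print("max error (default):", np.abs(np.asarray(default) - ref).max())
print("max error (highest):", np.abs(np.asarray(highest) - ref).max())
```

On a TF32-capable GPU the first error should be noticeably larger than the second; on a Turing card like the T4 the two should be essentially identical.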
I will close this issue immediately since I managed to solve it, but I am still posting it for future users. I think it would be helpful to mention the highest-matmul-precision setting in the README, because at the moment MuJoCo Playground doesn't work "out of the box" on GPUs that default to TF32 (Ampere and newer); instead it produces this hard-to-find and hard-to-debug issue.