
Performance worse on Ampere architectures (e.g. RTX 4080) compared to Tesla T4 (used in Colab) due to matmul precision discrepancy #86

@Yasu31

Description

Environment

Python version: 3.11.11
JAX Version: 0.5.2
GPU Details:

  • Local System: RTX 4080 (Ampere architecture) with NVIDIA driver 550, AMD Ryzen 7 CPU on Ubuntu.
  • Google Colab: Tesla T4 (Turing architecture) with NVIDIA driver 550, Intel Xeon CPU.

Issue

I noticed that the RL performance in the manipulation.ipynb example varied wildly between the sample Colab instance and my PC. Namely, the task performance degrades rapidly after the first iteration on my PC, while it generally keeps improving on Colab:

[Training curves: Colab (Tesla T4) vs. my PC (RTX 4080)]

Also, when I ran the train_jax_ppo.py script for the default LeapCubeReorient task, it did not train; the average episode length even drops to zero after the first iteration:
[Training curves for LeapCubeReorient on the RTX 4080]

I tried to match the software versions (e.g. the Python libraries jax, jaxlib, jax-cuda12-plugin, brax, mujoco, and flax, as well as the NVIDIA driver and CUDA versions), but that did not fix the discrepancy.

Fix

After some searching, this Q&A pointed me to this answer, which suggests increasing the matmul precision in JAX to "high". That still didn't fix the performance for the LeapCubeReorient task (though it did work for the first PandaPickCubeOrientation task), so I then tried:

jax.config.update('jax_default_matmul_precision', 'highest')

which fixed the performance issue, as you can see from this Weights & Biases screenshot:

[Weights & Biases screenshot showing the recovered training performance]
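For reference, the same setting can be applied in a few equivalent ways; below is a quick sketch. The jax.config.update call and the context manager are standard JAX APIs; the environment-variable form relies on JAX picking up config options from upper-cased environment variables, which I believe it does but have not verified here. The train() call is just a placeholder for whatever training entry point you use.

# Option 1: set it globally at the top of the script/notebook, before any JAX computation runs.
import jax
jax.config.update('jax_default_matmul_precision', 'highest')

# Option 2: scope it to a block of code with the context manager.
with jax.default_matmul_precision('highest'):
    train()  # hypothetical training entry point, not part of the repo

# Option 3: set it from the shell without touching the code, e.g.
#   JAX_DEFAULT_MATMUL_PRECISION=highest python train_jax_ppo.py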

Cause

This seems to result from the fact that on the RTX 4080, JAX defaults to TF32 matmuls (a reduced-precision format available on Ampere and newer architectures, less precise than full float32, as noted in issues like jax-ml/jax#12008) to take advantage of the hardware's tensor cores. On the Tesla T4 (Turing), which does not support TF32, full float32 precision is used by default.
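To see the effect directly, here is a minimal check of my own (not from the repo) that compares a float32 matmul against a float64 reference under the default and highest precision settings; on a TF32-capable GPU the default-precision error should be several orders of magnitude larger, while on a T4 or on CPU the two should roughly match.

import jax
import jax.numpy as jnp
import numpy as np

# Two random float32 matrices.
a = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024), dtype=jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), (1024, 1024), dtype=jnp.float32)

# float64 reference computed with NumPy on the host.
ref = np.asarray(a, dtype=np.float64) @ np.asarray(b, dtype=np.float64)

for precision in (jax.lax.Precision.DEFAULT, jax.lax.Precision.HIGHEST):
    out = jnp.dot(a, b, precision=precision)
    err = np.max(np.abs(np.asarray(out, dtype=np.float64) - ref))
    print(precision, err)  # DEFAULT is noticeably less accurate when TF32 kicks in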

I will close this issue immediately as I managed to solve it, but I'm posting it for future users. It might also be worth adding a note to the README about setting the matmul precision to "highest", since right now mujoco playground doesn't work "out of the box" for Ampere architecture users and instead produces this hard-to-find, hard-to-debug issue.
