
Orbax - Checkpointing for JAX Models


Orbax provides common checkpointing and persistence utilities for JAX users.

Refer to the full Orbax documentation for details.

Installation

Orbax is available on PyPI as separate domain-specific packages:

Checkpointing

Install from PyPI:

pip install orbax-checkpoint

Or install the latest version directly from GitHub at HEAD:

pip install 'git+https://github.com/google/orbax/#subdirectory=checkpoint'

Exporting

Install from PyPI:

pip install orbax-export

Or install the latest version directly from GitHub at HEAD:

pip install 'git+https://github.com/google/orbax/#subdirectory=export'

Quickstart

import jax
from orbax.checkpoint import v1 as ocp

# Define your pytree state (e.g. weights, optimizer state)
state = {'a': jax.numpy.ones(2), 'b': 42}

# Save the state
ocp.save('/tmp/my_checkpoint', state)

# Restore the state
restored_state = ocp.load('/tmp/my_checkpoint')
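Conceptually, saving a pytree means recording every leaf together with its position in the tree, so that loading can rebuild the same structure. The toy sketch below illustrates that round trip for nested dicts using only the standard library. The helpers `flatten`, `unflatten`, `toy_save` and `toy_load` are hypothetical illustrations, not Orbax's API or on-disk format, and they assume dict keys contain no `.` and leaves are JSON-serializable.

```python
import json
import os
import tempfile
from typing import Any, Dict, Tuple


def flatten(tree: Dict[str, Any], prefix: Tuple[str, ...] = ()) -> Dict[str, Any]:
    """Hypothetical helper: map each leaf of a nested dict to a dotted path like 'opt.step'."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        path = prefix + (str(key),)
        if isinstance(value, dict):
            flat.update(flatten(value, path))  # recurse into subtrees
        else:
            flat[".".join(path)] = value  # anything that is not a dict is a leaf
    return flat


def unflatten(flat: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: rebuild the nested dict (assumes keys contain no '.')."""
    tree: Dict[str, Any] = {}
    for path, leaf in flat.items():
        *parents, last = path.split(".")
        node = tree
        for name in parents:
            node = node.setdefault(name, {})
        node[last] = leaf
    return tree


def toy_save(directory: str, tree: Dict[str, Any]) -> None:
    """Hypothetical: write every leaf, keyed by its path, to a single JSON file."""
    os.makedirs(directory, exist_ok=True)
    with open(os.path.join(directory, "leaves.json"), "w") as f:
        json.dump(flatten(tree), f)


def toy_load(directory: str) -> Dict[str, Any]:
    """Hypothetical: read the leaves back and restore the original nesting."""
    with open(os.path.join(directory, "leaves.json")) as f:
        return unflatten(json.load(f))


# Usage, mirroring the quickstart with plain Python values instead of jax arrays.
state = {"a": [1.0, 1.0], "b": 42, "opt": {"step": 3}}
with tempfile.TemporaryDirectory() as tmp:
    toy_save(os.path.join(tmp, "my_checkpoint"), state)
    restored_state = toy_load(os.path.join(tmp, "my_checkpoint"))
print(flatten(state))  # {'a': [1.0, 1.0], 'b': 42, 'opt.step': 3}
assert restored_state == state
```

Keying leaves by path rather than by position means a reader can look up a single leaf by name without walking the whole tree; a production library also has to handle arrays, sharding and other leaf types that this sketch ignores.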

Orbax includes a checkpointing library oriented towards JAX users, supporting features required by a range of frameworks, including asynchronous checkpointing, standard and custom types, and flexible storage formats. We aim to provide a highly customizable and composable API that maximizes flexibility for diverse use cases.
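The idea behind asynchronous checkpointing is to split a save into a short blocking step (taking a snapshot of the state) and a slow non-blocking step (writing it to storage) so that training can continue while the write is in flight. The following is a minimal, standard-library-only sketch of that pattern under stated assumptions: the class `ToyAsyncCheckpointer` and its `save`, `wait` and `close` methods are hypothetical names, state is a JSON-serializable dict, and nothing here reflects how Orbax itself is implemented.

```python
import copy
import json
import os
import tempfile
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Dict, List


class ToyAsyncCheckpointer:
    """Hypothetical sketch of async checkpointing; not Orbax's implementation."""

    def __init__(self) -> None:
        # A single worker keeps writes ordered and off the caller's thread.
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._pending: List[Future] = []

    def save(self, path: str, state: Dict[str, Any]) -> None:
        # Blocking part: snapshot the state so the caller may keep mutating it.
        snapshot = copy.deepcopy(state)
        # Non-blocking part: the slow write is queued on the background thread.
        self._pending.append(self._executor.submit(self._write, path, snapshot))

    @staticmethod
    def _write(path: str, snapshot: Dict[str, Any]) -> None:
        tmp_path = path + ".tmp"
        with open(tmp_path, "w") as f:
            json.dump(snapshot, f)
        # os.replace is atomic on POSIX within one filesystem, so readers
        # never observe a half-written checkpoint file.
        os.replace(tmp_path, path)

    def wait(self) -> None:
        # Block until every queued save has finished; re-raises any write error.
        for future in self._pending:
            future.result()
        self._pending.clear()

    def close(self) -> None:
        self.wait()
        self._executor.shutdown()


# Usage: "train" for a few steps, saving after each without waiting on disk I/O.
with tempfile.TemporaryDirectory() as tmp:
    checkpointer = ToyAsyncCheckpointer()
    state = {"step": 0, "w": [0.0, 0.0]}
    for step in range(1, 4):
        state["step"] = step
        state["w"][:] = [x + 1.0 for x in state["w"]]  # mutated in place
        checkpointer.save(os.path.join(tmp, f"ckpt_{step}.json"), state)
    checkpointer.close()  # make sure all writes landed before reading
    with open(os.path.join(tmp, "ckpt_1.json")) as f:
        first = json.load(f)
    with open(os.path.join(tmp, "ckpt_3.json")) as f:
        last = json.load(f)
print(first, last)  # {'step': 1, 'w': [1.0, 1.0]} {'step': 3, 'w': [3.0, 3.0]}
```

The snapshot is what makes the pattern safe: because `state["w"]` is mutated in place after each `save`, writing the live object instead of a copy could record values from a later step. In a JAX training loop the analogous blocking step would typically be copying arrays off the accelerator, which this sketch approximates with `copy.deepcopy`.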

Support

Please report any issues or request support using our issue tracker.

Please also reach out to orbax-dev@google.com directly for help or with any questions about Orbax.

Citing Orbax

@misc{gaffney2026orbaxdistributedcheckpointingjax,
      title={Orbax: Distributed Checkpointing with JAX},
      author={Colin Gaffney and Shutong Li and Daniel Ng and Anastasia Petrushkina and Niket Kumar and Adam Cogdell and Mridul Sahu and Yaning Liang and Nikhil Bansal and Justin Pan and Angel Mau and Abhishek Agrawal and Marco Berlot and Ruoxin Sang and Kiranbir Sodhia and Rakesh Iyer},
      year={2026},
      eprint={2605.23066},
      archivePrefix={arXiv},
      primaryClass={cs.DC},
      url={https://arxiv.org/abs/2605.23066},
}

Existing Users

Orbax Checkpointing is used extensively across JAX machine learning frameworks and model implementations.

Google Projects

  • Flax (Google's flexible and expressive neural network library for JAX)
  • Gemma (Open foundation models by Google DeepMind)
  • Kauldron (Google Research training and evaluation framework)
  • PaxML (Google's high-performance framework for training large-scale JAX models)
  • T5X (Google's JAX framework for high-performance sequence models)
  • MaxText (Google's high-performance, scalable JAX LLM implementation)
  • MaxDiffusion (Stable Diffusion JAX training library optimized for Cloud TPUs)
  • Tunix (Google's JAX-native library for LLM post-training)
  • Numerous Google-internal ML frameworks

Non-Google Projects

  • AXLearn (Apple's high-performance deep learning library built on top of JAX)
  • openpi (Robotics foundation models by Physical Intelligence)
