Orbax provides common checkpointing and persistence utilities for JAX users.
Refer to our full documentation here.
Orbax is available on PyPI as separate domain-specific packages:

Checkpointing (orbax-checkpoint). Install from PyPI:
pip install orbax-checkpoint

Or install the latest version directly from GitHub at HEAD:
pip install 'git+https://github.com/google/orbax/#subdirectory=checkpoint'

Export (orbax-export). Install from PyPI:
pip install orbax-export

Or install the latest version directly from GitHub at HEAD:
pip install 'git+https://github.com/google/orbax/#subdirectory=export'

Quickstart (saving and restoring a pytree):

import jax
from orbax.checkpoint import v1 as ocp
# Define your pytree state (e.g. weights, optimizer state)
state = {'a': jax.numpy.ones(2), 'b': 42}
# Save the state
ocp.save('/tmp/my_checkpoint', state)
# Restore the state
restored_state = ocp.load('/tmp/my_checkpoint')

Orbax includes a checkpointing library oriented towards JAX users that supports the features different frameworks need, including asynchronous checkpointing, standard and custom types, and flexible storage formats. We aim to provide a highly customizable, composable API that accommodates a wide range of use cases.
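Asynchronous checkpointing lets training continue while a checkpoint is being written. The sketch below illustrates the general idea using only the Python standard library: the caller blocks just long enough to snapshot the state, and serialization happens on a background thread. `AsyncCheckpointer`, `wait_until_finished`, and `load` are hypothetical names chosen for illustration, not Orbax's API, and pickle stands in for a real storage format. Orbax's own implementation is considerably more involved (for example, it must handle JAX arrays sharded across devices and hosts).

```python
import copy
import os
import pickle
import tempfile
import threading


class AsyncCheckpointer:
    """Hypothetical minimal async checkpointer (illustrative only, not Orbax's API)."""

    def __init__(self):
        self._thread = None

    def save(self, path, state):
        # Finish any in-flight save first so writes never overlap.
        self.wait_until_finished()
        # Blocking part: snapshot the state so later in-place updates by the
        # caller cannot change what ends up in the checkpoint.
        snapshot = copy.deepcopy(state)
        # Non-blocking part: serialize and write on a background thread.
        self._thread = threading.Thread(target=self._write, args=(path, snapshot))
        self._thread.start()

    @staticmethod
    def _write(path, snapshot):
        # Write to a temporary file, then rename it into place (atomic on
        # POSIX filesystems), so readers never see a half-written checkpoint.
        tmp_path = path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(snapshot, f)
        os.replace(tmp_path, path)

    def wait_until_finished(self):
        # Block until the background write (if any) has completed.
        if self._thread is not None:
            self._thread.join()
            self._thread = None


def load(path):
    # Hypothetical loader matching the pickle format used above.
    with open(path, "rb") as f:
        return pickle.load(f)


ckpt_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(ckpt_dir, "step_1.pkl")
state = {"a": [1.0, 1.0], "b": 42}

checkpointer = AsyncCheckpointer()
checkpointer.save(ckpt_path, state)
state["b"] += 1  # "training" continues; only the live state changes
checkpointer.wait_until_finished()

restored = load(ckpt_path)
print(restored)  # {'a': [1.0, 1.0], 'b': 42}
```

The snapshot-then-write split is the key design choice: the expensive serialization and I/O overlap with the next training steps, while the checkpoint still reflects the state at the moment `save` was called.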
Please report any issues or request support using our issue tracker.
You can also contact orbax-dev@google.com directly for help or with questions about Orbax.
To cite Orbax, use the following BibTeX entry:
@misc{gaffney2026orbaxdistributedcheckpointingjax,
title={Orbax: Distributed Checkpointing with JAX},
author={Colin Gaffney and Shutong Li and Daniel Ng and Anastasia Petrushkina and Niket Kumar and Adam Cogdell and Mridul Sahu and Yaning Liang and Nikhil Bansal and Justin Pan and Angel Mau and Abhishek Agrawal and Marco Berlot and Ruoxin Sang and Kiranbir Sodhia and Rakesh Iyer},
year={2026},
eprint={2605.23066},
archivePrefix={arXiv},
primaryClass={cs.DC},
url={https://arxiv.org/abs/2605.23066},
}
Orbax checkpointing is used extensively across JAX machine learning frameworks and model implementations, including:
- Flax (Google's flexible and expressive neural network library for JAX)
- Gemma (Open foundation models by Google DeepMind)
- Kauldron (Google Research training and evaluation framework)
- PaxML (Google's high-performance framework for training large-scale JAX models)
- T5X (Google's JAX framework for high-performance sequence models)
- MaxText (Google's high-performance, scalable JAX LLM implementation)
- MaxDiffusion (Stable Diffusion JAX training library optimized for Cloud TPUs)
- Tunix (Google's JAX-native library for LLM post-training)
- Numerous Google-internal ML frameworks