Introduction | Inference | Training | Installation | Composability | Custom Kernels | Prototype Features | Integrations | Videos | License | Citation
torchao: PyTorch library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training.
From the team that brought you the fast series
- 9.5x speedups for Image segmentation models with sam-fast
- 10x speedups for Language models with gpt-fast
- 3x speedup for Diffusion models with sd-fast
torchao just works with torch.compile() and FSDP2 over most PyTorch models on Huggingface out of the box.
Our optimizations deliver significant speedups and memory savings:
- INT4 Weight-Only Quantization: 2x higher throughput (201 vs 107 tokens/sec) with 65% less memory (4.9GB vs 13.9GB) on LLaMA-2-7B
- Float8 Dynamic Quantization: Demonstrates 53.88% speedup on Flux.1-Dev* and 27.33% speedup on CogVideoX-5b on H100 GPU while preserving image quality
- INT4 + 2:4 Sparsity: 2.4x throughput increase (226 vs 95 tokens/sec) with 68% memory reduction (5.3GB vs 16.4GB peak memory) on LLaMA-3-8B
For detailed benchmarks across models and techniques, see our quantization documentation.
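The headline ratios in the list above can be re-derived from the raw figures quoted in parentheses. The check below is plain arithmetic on those numbers; the helper names are illustrative and are not part of torchao.

```python
def speedup(new_tok_s: float, base_tok_s: float) -> float:
    """Throughput ratio of the optimized run over the baseline."""
    return new_tok_s / base_tok_s


def memory_saving(new_gb: float, base_gb: float) -> float:
    """Fraction of the baseline memory that is no longer needed."""
    return 1.0 - new_gb / base_gb


# Figures quoted in the list above (tokens/sec and GB).
int4_speedup = speedup(201, 107)
int4_saving = memory_saving(4.9, 13.9)
sparse_speedup = speedup(226, 95)
sparse_saving = memory_saving(5.3, 16.4)

print(f"INT4 weight-only:    {int4_speedup:.2f}x, {int4_saving:.1%} less memory")
print(f"INT4 + 2:4 sparsity: {sparse_speedup:.2f}x, {sparse_saving:.1%} less memory")
```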
Quantizing and sparsifying your models is a one-liner that should work on any model with an nn.Linear, including your favorite HuggingFace model. You can find more comprehensive usage instructions here, sparsity instructions here, and a HuggingFace inference example here.
For inference, we have the option to:
- Quantize only the weights: works best for memory bound models
- Quantize the weights and activations: works best for compute bound models
- Quantize the activations and weights and sparsify the weights
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int8_weight,
    float8_dynamic_activation_float8_weight,
    int4_weight_only,
)
quantize_(m, int4_weight_only())
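To make the weight-only option more concrete, here is a minimal pure-Python sketch of 4-bit quantization for a single group of weights. It illustrates the general idea only: the asymmetric min/max scheme and the function names are assumptions for this example, not the kernel or layout that int4_weight_only() uses.

```python
from typing import List, Tuple


def quantize_group_int4(weights: List[float]) -> Tuple[List[int], float, float]:
    """Asymmetric 4-bit quantization of one group: returns (codes, scale, zero)."""
    lo, hi = min(weights), max(weights)
    # 4 bits give 16 levels, i.e. integer codes 0..15.
    scale = (hi - lo) / 15 if hi > lo else 1.0
    codes = [min(15, max(0, round((w - lo) / scale))) for w in weights]
    return codes, scale, lo


def dequantize_group(codes: List[int], scale: float, zero: float) -> List[float]:
    """Map integer codes back to approximate floating point weights."""
    return [c * scale + zero for c in codes]


group = [-0.31, -0.12, 0.0, 0.07, 0.2, 0.44]
codes, scale, zero = quantize_group_int4(group)
restored = dequantize_group(codes, scale, zero)
max_error = max(abs(a - b) for a, b in zip(group, restored))
print(codes, f"max error = {max_error:.4f}")
```

Each weight is stored as a 4-bit code plus a shared scale and zero point per group, and the rounding error per weight is bounded by half the scale.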
For gpt-fast, int4_weight_only() is the best option at bs=1, as it doubles the tok/s and reduces the VRAM requirements by about 65% over a torch.compiled baseline.
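One way to see why weight-only quantization helps most at bs=1 is to estimate the arithmetic intensity (FLOPs per byte moved) of a linear layer. The sketch below is a rough roofline-style estimate under simplifying assumptions: each weight and activation is read or written exactly once, and the 4096 x 4096 layer size is hypothetical.

```python
def linear_arithmetic_intensity(batch: int, in_features: int, out_features: int,
                                weight_bits: int, act_bytes: int = 2) -> float:
    """FLOPs per byte moved for y = x @ W, counting weights and activations once."""
    flops = 2 * batch * in_features * out_features
    weight_bytes = in_features * out_features * weight_bits / 8
    act_bytes_total = batch * (in_features + out_features) * act_bytes
    return flops / (weight_bytes + act_bytes_total)


# A hypothetical 4096 x 4096 projection layer.
for bits in (16, 4):
    for batch in (1, 256):
        ai = linear_arithmetic_intensity(batch, 4096, 4096, bits)
        print(f"weight_bits={bits:2d} batch={batch:3d} -> {ai:7.1f} FLOPs/byte")
```

At batch size 1 the layer performs about one FLOP per byte of 16-bit weights, so runtime is dominated by reading weights and shrinking them pays off almost directly. At large batch sizes the same layer performs hundreds of FLOPs per byte, which is where quantizing activations as well becomes more relevant.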
If you don't have enough VRAM to quantize your entire model on GPU and you find CPU quantization too slow, you can use the device argument, as in quantize_(model, int8_weight_only(), device="cuda"), which sends each layer to your GPU and quantizes it there, one layer at a time.
If you see slowdowns with any of these techniques or you're unsure which option to use, consider using autoquant which will automatically profile layers and pick the best way to quantize each layer.
model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
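The idea of profiling each layer and keeping the fastest option can be sketched in a few lines. This is a conceptual illustration only, not how torchao.autoquant is implemented: the layer names, candidate names, and latency table are all hypothetical, and a real profiler would time the layer itself.

```python
from typing import Callable, Dict, List


def pick_fastest_per_layer(
    layers: List[str],
    candidates: List[str],
    measure: Callable[[str, str], float],
) -> Dict[str, str]:
    """For every layer, time each candidate and keep the one with the lowest latency."""
    choice: Dict[str, str] = {}
    for layer in layers:
        timings = {name: measure(layer, name) for name in candidates}
        choice[layer] = min(timings, key=timings.get)
    return choice


# Hypothetical latencies in milliseconds, standing in for real measurements.
fake_latency_ms = {
    ("attn.qkv", "no_quant"): 0.40,
    ("attn.qkv", "int8_weight_only"): 0.25,
    ("attn.qkv", "int8_dynamic"): 0.31,
    ("mlp.up", "no_quant"): 0.90,
    ("mlp.up", "int8_weight_only"): 0.95,
    ("mlp.up", "int8_dynamic"): 0.60,
    ("lm_head", "no_quant"): 0.20,
    ("lm_head", "int8_weight_only"): 0.22,
    ("lm_head", "int8_dynamic"): 0.27,
}

plan = pick_fastest_per_layer(
    layers=["attn.qkv", "mlp.up", "lm_head"],
    candidates=["no_quant", "int8_weight_only", "int8_dynamic"],
    measure=lambda layer, name: fake_latency_ms[(layer, name)],
)
print(plan)
```

Note that leaving a layer unquantized is a valid outcome when no candidate beats the baseline.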
We also provide a developer-facing API so you can implement your own quantization algorithms; the excellent HQQ algorithm serves as a motivating example.
We've added kv cache quantization and other features in order to enable long context length (and necessarily memory efficient) inference.
In practice, these features alongside int4 weight-only quantization allow us to reduce peak memory by ~55%, meaning we can run Llama3.1-8B inference with a 130k context length using only 18.9 GB of peak memory. More details can be found here
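To see why the KV cache matters at long context lengths, the sketch below estimates its size from the model shape. The shape values (32 layers, 8 KV heads, head dimension 128) are assumptions taken from the public Llama3.1-8B configuration, and the estimate ignores quantization scales and every other source of memory use, so it does not try to reproduce the 18.9 GB figure.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_value: int, batch: int = 1) -> int:
    """Bytes needed to cache keys and values for every layer and position."""
    # Factor of 2: one key tensor and one value tensor per layer.
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_value * batch


# Assumed Llama3.1-8B shape: 32 layers, 8 KV heads, head dimension 128.
shape = dict(num_layers=32, num_kv_heads=8, head_dim=128, seq_len=130_000)
bf16_gb = kv_cache_bytes(**shape, bytes_per_value=2) / 1e9
int8_gb = kv_cache_bytes(**shape, bytes_per_value=1) / 1e9
print(f"bf16 KV cache: {bf16_gb:.2f} GB, int8 KV cache: {int8_gb:.2f} GB")
```

Under these assumptions a 16-bit cache alone is larger than the 16-bit weights of an 8B-parameter model, which is why quantizing the cache is needed in addition to quantizing the weights.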
Post-training quantization (PTQ) can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering 96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext for Llama3. We've provided a full recipe here
import torch
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)
# Insert fake quantization
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)
# Run training... (not shown)
# Convert fake quantization to actual quantized operations
quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
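"Fake quantization" in the steps above means rounding values to the quantized grid and immediately mapping them back to floats, so training sees the quantization error while still running in floating point. Below is a minimal pure-Python sketch of that round trip for one group of weights. The symmetric scheme and the function name are assumptions for illustration, not torchao's implementation, and the gradient handling that QAT implementations typically add (treating the rounding as identity in the backward pass) is not shown.

```python
from typing import List


def fake_quantize_group(weights: List[float], bits: int = 4) -> List[float]:
    """Round weights to a symmetric integer grid, then map them straight back to floats."""
    qmin, qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    out = []
    for w in weights:
        q = min(qmax, max(qmin, round(w / scale)))  # integer code in [qmin, qmax]
        out.append(q * scale)  # stored as a float, unlike real quantization
    return out


weights = [0.70, -0.33, 0.12, -0.04]
simulated = fake_quantize_group(weights)
print(simulated)
```

The converted model later replaces this float round trip with actual low-bit storage, which is what the final two quantize_ calls do.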
torchao.float8 implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433.
With torch.compile on, current results show throughput speedups of up to 1.5x on 128 H100 GPU LLaMa 3 70B pretraining jobs (details)
from torchao.float8 import convert_to_float8_training
convert_to_float8_training(m, module_filter_fn=...)
And for an end-to-end minimal training recipe of pretraining with float8, you can check out torchtitan
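The "scaled" part of scaled float8 can be illustrated without any float8 hardware. The sketch below shows only per-tensor dynamic scaling: each operand is multiplied by a scale that maps its largest magnitude onto the float8 e4m3 maximum (448), and the product is divided by both scales afterwards. The actual cast to float8 is omitted, the helper name is hypothetical, and torchao.float8 may use other scaling recipes than the one assumed here.

```python
from typing import List

E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn


def tensorwise_scale(values: List[float], fp8_max: float = E4M3_MAX) -> float:
    """Scale that maps the largest magnitude in `values` onto the float8 maximum."""
    amax = max(abs(v) for v in values)
    return fp8_max / amax if amax > 0 else 1.0


a = [0.002, -0.015, 0.0007, 0.031]
b = [120.0, -35.0, 800.0, 64.0]
scale_a, scale_b = tensorwise_scale(a), tensorwise_scale(b)

# These scaled values are what would be cast to float8 before the matmul.
a_scaled = [v * scale_a for v in a]
b_scaled = [v * scale_b for v in b]

# The product is computed on the scaled operands and then rescaled back.
dot_scaled = sum(x * y for x, y in zip(a_scaled, b_scaled)) / (scale_a * scale_b)
dot_reference = sum(x * y for x, y in zip(a, b))
print(dot_scaled, dot_reference)
```

Scaling lets tensors with very different magnitudes both use the full float8 range, which is what keeps the narrow format usable for training.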
We've added support for semi-structured 2:4 sparsity with 6% end-to-end speedups on ViT-L. Full blog here
The code change is a one-liner, with the full example available here
from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear
swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})
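For readers new to the 2:4 pattern: in every contiguous group of four weights, at most two are non-zero. The sketch below produces that pattern by keeping the two largest magnitudes per group. Magnitude-based selection is an assumption made for this illustration; how SemiSparseLinear chooses which values to keep is not covered here.

```python
from typing import List


def prune_2_of_4(row: List[float]) -> List[float]:
    """Zero the two smallest-magnitude entries in every contiguous group of four."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    pruned: List[float] = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two largest magnitudes in this group.
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        pruned.extend(v if i in keep else 0.0 for i, v in enumerate(group))
    return pruned


row = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.25, 0.01]
sparse_row = prune_2_of_4(row)
print(sparse_row)
```

Because the zeros follow a fixed 2-out-of-4 structure, hardware with semi-structured sparsity support can skip them, which is where the speedup comes from.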
Adam keeps two state tensors per parameter, so its state takes 2x as much memory as the model params. We can quantize the optimizer state to either 8 or 4 bits, effectively reducing the optimizer VRAM requirements by 2x or 4x respectively over an fp16 baseline
from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit, AdamWFp8
optim = AdamW8bit(model.parameters()) # replace with AdamW4bit or AdamWFp8 for the 4-bit / fp8 versions
In practice, we are a tiny bit slower than expertly written kernels, but these optimizers were implemented in a few hundred lines of PyTorch code and compiled, so please use them or copy-paste them for your quantized optimizers. Benchmarks here
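The 2x and 4x savings follow directly from the bit widths. The sketch below works through the arithmetic for a hypothetical 7B-parameter model; it counts only Adam's two per-parameter state tensors and ignores the small overhead of quantization scales.

```python
def adam_state_gb(num_params: int, bits_per_value: int) -> float:
    """Memory for Adam's two per-parameter state tensors (first and second moments)."""
    return 2 * num_params * bits_per_value / 8 / 1e9


num_params = 7_000_000_000  # hypothetical 7B-parameter model
for label, bits in (("fp16", 16), ("8-bit", 8), ("4-bit", 4)):
    print(f"{label:>5}: {adam_state_gb(num_params, bits):5.1f} GB of optimizer state")
```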
We also have support for single-GPU CPU offloading, where both the gradients (same size as weights) and the optimizer state are efficiently sent to the CPU. This alone can reduce your VRAM requirements by 60%
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
optim.load_state_dict(ckpt["optim"])
torchao makes liberal use of several new features in PyTorch, so we recommend using it with the current nightly or the latest stable version of PyTorch. See getting started for more details.
Install the stable release (recommended):
pip install torchao
Other options:
# Nightly build
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu124
# Different CUDA versions
pip install torchao --index-url https://download.pytorch.org/whl/cu118 # CUDA 11.8
pip install torchao --index-url https://download.pytorch.org/whl/cpu # CPU only
# Development install from a source checkout
USE_CPP=0 python setup.py develop # Skip C++/CUDA extensions
torch.compile: A key design principle for us is composability - any custom dtype or memory layout should work with our compiler. We enable kernel implementations in PyTorch, CUDA, C++, or Triton. This allows researchers and engineers to start with high-level dtype and layout logic in pure PyTorch, then progressively optimize performance by implementing lower-level kernels as needed, while maintaining compatibility with the compile infrastructure.
FSDP2: Historically, most quantization has been done for inference, but there is now a thriving area of research combining distributed algorithms and quantization.
The best example we have of combining the composability of lower-bit dtypes with compile and FSDP is NF4, which we used to implement the QLoRA algorithm. If you're doing research at this intersection, we'd love to hear from you.
Our framework makes it straightforward to add tensor parallel support to your custom quantized tensor subclass. Check out our tensor parallel tutorial to see how a quantized tensor subclass can be extended to support column and row-wise tensor sharding while maintaining compatibility with torch.compile.
We've added support for authoring and releasing custom ops that do not graph break with torch.compile(). We have a few examples you can follow:
- fp6 for 2x faster inference over fp16 with an easy to use API: quantize_(model, fpx_weight_only(3, 2))
- 2:4 Sparse Marlin GEMM with 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
- int4 tinygemm unpacker which makes it easier to switch quantized backends for inference
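As background for the fp6 example in the list above, the two arguments in fpx_weight_only(3, 2) are read here as 3 exponent bits and 2 mantissa bits, which together with a sign bit make 6 bits. The sketch below enumerates the non-negative values of such a format. It assumes an IEEE-style exponent bias, subnormals, and no encodings reserved for inf or NaN; those conventions are assumptions of this illustration and are not taken from the torchao source.

```python
from typing import List


def fpx_positive_values(ebits: int, mbits: int) -> List[float]:
    """Non-negative values of a small float format with subnormals and no inf/NaN codes."""
    bias = 2 ** (ebits - 1) - 1
    values = set()
    for e in range(2 ** ebits):
        for m in range(2 ** mbits):
            frac = m / 2 ** mbits
            if e == 0:
                values.add(frac * 2.0 ** (1 - bias))  # subnormal (includes zero)
            else:
                values.add((1 + frac) * 2.0 ** (e - bias))  # normal
    return sorted(values)


fp6_values = fpx_positive_values(ebits=3, mbits=2)
print(len(fp6_values), fp6_values[1], fp6_values[-1])
```

With so few representable values, such a format is typically paired with a floating point scale so that the weights of a layer fall within its range.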
If you believe there are other CUDA kernels we should be taking a closer look at, please leave a comment on this issue or feel free to contribute directly to the repo.
Check out our prototype directory, where we experiment with cutting-edge model optimization techniques for both training and inference. If you're interested in contributing experimental research or want to explore, feel free to open an issue or PR.
We're also fortunate to be integrated into some of the leading open-source libraries including
- Hugging Face transformers with a builtin inference backend and low bit optimizers
- Hugging Face diffusers best practices with torch.compile and torchao in a standalone repo diffusers-torchao
- Mobius HQQ backend leveraged our int4 kernels to get 195 tok/s on a 4090
- TorchTune for our QLoRA and QAT recipes
- torchchat for post training quantization
- SGLang for LLM serving: usage and the major PR.
- Keynote talk at GPU MODE IRL
- Low precision dtypes at PyTorch conference
- Slaying OOMs at the Mastering LLM's course
- Advanced Quantization at CUDA MODE
- Chip Huyen's GPU Optimization Workshop
- Cohere for AI community talk
torchao is released under the BSD 3 license.
If you find the torchao library useful, please cite it in your work as below.
@software{torchao,
  title = {torchao: PyTorch native quantization and sparsity for training and inference},
  author = {torchao maintainers and contributors},
  url = {https://github.com/pytorch/torchao},
  license = {BSD-3-Clause},
  month = oct,
  year = {2024}
}