Support AMP with TPUs #17927
Labels: fabric, feature, precision: amp, strategy: xla, trainer
Description & Motivation
Lightning currently supports `accelerator="tpu", precision="bf16-mixed"`, but so far this just sets the `XLA_USE_BF16` environment variable.

The XLA team added support for automatic mixed precision (AMP). `XLA:GPU` uses a `GradScaler` and the `autocast` context manager, whereas `XLA:TPU` just uses the latter: https://github.com/pytorch/xla/blob/c9f2d91a234cdaf91f0bbdb044ec94e297ac839a/test/test_train_mp_mnist_amp.py#L143-L147
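For reference, a minimal sketch of that usage pattern, adapted from the linked test; the toy model, the data, and the `xm.xla_device_hw` check are only illustrative:

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import GradScaler, autocast

device = xm.xla_device()
model = torch.nn.Linear(16, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

# XLA:GPU pairs autocast with a GradScaler; XLA:TPU only needs autocast,
# since bfloat16 keeps the float32 exponent range and needs no loss scaling.
scaler = GradScaler() if xm.xla_device_hw(device) == "GPU" else None

for _ in range(3):
    data = torch.randn(8, 16, device=device)
    target = torch.randint(0, 2, (8,), device=device)
    optimizer.zero_grad()
    with autocast(device):  # bfloat16 on TPU, float16 on GPU
        loss = loss_fn(model(data), target)
    if scaler is not None:
        # Multi-process training would also all-reduce gradients before
        # stepping, as done in the linked test.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        xm.optimizer_step(optimizer)  # reduces gradients and steps the optimizer
```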
Pitch

Integrate `from torch_xla.amp import autocast, GradScaler`.

The code would be very similar to the non-XLA AMP plugin: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/amp.py

This would likely replace our existing `XLABf16Precision` plugin with an `XLAMixedPrecision` plugin, sketched below.
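A very rough sketch of what that plugin could look like; the hooks (`forward_context`, `backward`, `optimizer_step`) are borrowed from Fabric's `Precision` interface, and all names and signatures here are illustrative rather than a final design:

```python
from typing import Any, Optional

import torch
from lightning.fabric.plugins.precision import Precision
from torch_xla.amp import GradScaler, autocast


class XLAMixedPrecision(Precision):
    """AMP for XLA devices: autocast everywhere, GradScaler only on XLA:GPU."""

    def __init__(self, device: torch.device, scaler: Optional[GradScaler] = None) -> None:
        super().__init__()
        self.device = device
        # TPUs autocast to bfloat16 and don't need loss scaling, so `scaler` stays None there.
        self.scaler = scaler

    def forward_context(self) -> Any:
        # Wrap forward passes in the XLA autocast context.
        return autocast(self.device)

    def backward(self, tensor: torch.Tensor, model: Any, *args: Any, **kwargs: Any) -> None:
        # Scale the loss before backward only when a scaler is configured (XLA:GPU).
        if self.scaler is not None:
            tensor = self.scaler.scale(tensor)
        super().backward(tensor, model, *args, **kwargs)

    def optimizer_step(self, optimizer: Any, **kwargs: Any) -> Any:
        # Without a scaler (XLA:TPU), defer to the default optimizer step.
        if self.scaler is None:
            return super().optimizer_step(optimizer, **kwargs)
        step_output = self.scaler.step(optimizer, **kwargs)
        self.scaler.update()
        return step_output
```

On TPU the plugin would be constructed without a scaler, so only the `autocast` forward context changes the training step.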
Alternatives
The XLA AMP support was just merged upstream, so it's likely still very experimental. I expect it will be released with PyTorch 2.1.
Additional context
PR on PyTorch: pytorch/pytorch#96370
PR on XLA: pytorch/xla#5161
cc @Borda @carmocca @justusschock @awaelchli @JackCaoG @steventk-g @Liyang90