
Support AMP with TPUs #17927

Open
carmocca opened this issue Jun 26, 2023 · 0 comments
Labels: fabric (lightning.fabric.Fabric), feature (Is an improvement or enhancement), precision: amp (Automatic Mixed Precision), strategy: xla, trainer
Milestone: future

Comments

@carmocca
Contributor

carmocca commented Jun 26, 2023

Description & Motivation

Lightning currently supports `accelerator="tpu", precision="bf16-mixed"`, but so far this just sets the `XLA_USE_BF16` environment variable.

Side note: why does Fabric also move the data to bf16?

The XLA team added support for automatic mixed precision (AMP). XLA:GPU uses a GradScaler and the autocast context manager, whereas XLA:TPU just uses the latter: https://github.com/pytorch/xla/blob/c9f2d91a234cdaf91f0bbdb044ec94e297ac839a/test/test_train_mp_mnist_amp.py#L143-L147
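For reference, the training step in the linked torch_xla test roughly follows the pattern below. This is a minimal sketch, not Lightning code: it assumes a TPU environment with `torch_xla` installed, and `model`, `loss_fn`, `optimizer`, `data`, and `target` are placeholders.

```python
def tpu_amp_train_step(model, loss_fn, optimizer, data, target):
    """One AMP training step on XLA:TPU (sketch; requires torch_xla at call time)."""
    # Imported lazily so the function can be defined without torch_xla installed.
    import torch_xla.core.xla_model as xm
    from torch_xla.amp import autocast

    optimizer.zero_grad()
    # On XLA:TPU, the autocast context manager alone is enough; XLA:GPU
    # additionally wraps backward/step in a GradScaler, mirroring torch.cuda.amp.
    with autocast(xm.xla_device()):
        output = model(data)
        loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)  # applies the update and marks the XLA step
    return loss
```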

Pitch

Integrate `from torch_xla.amp import autocast, GradScaler`.

The code would be very similar to the non-XLA AMP plugin: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/amp.py

This would likely replace our existing XLABf16Precision plugin with an XLAMixedPrecision plugin.
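A minimal sketch of what such a plugin could look like, mirroring the structure of the non-XLA AMP plugin linked above. The class name `XLAMixedPrecision`, its constructor signature, and the `forward_context` method are assumptions for illustration, not the actual implementation.

```python
class XLAMixedPrecision:
    """Hypothetical precision plugin backed by torch_xla.amp instead of torch.cuda.amp."""

    def __init__(self, precision="bf16-mixed"):
        # Hypothetical precision strings, modeled on the non-XLA AMP plugin.
        if precision not in ("16-mixed", "bf16-mixed"):
            raise ValueError(f"Unsupported precision: {precision!r}")
        self.precision = precision

    def forward_context(self):
        # Imported lazily so the class can be defined without torch_xla installed.
        import torch_xla.core.xla_model as xm
        from torch_xla.amp import autocast

        # XLA:TPU only needs autocast; an XLA:GPU variant would also
        # carry a torch_xla.amp.GradScaler for the backward/step.
        return autocast(xm.xla_device())
```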

Alternatives

This was just merged upstream, so it's likely still very experimental. I expect it will be released with PyTorch 2.1.

Additional context

PR on PyTorch: pytorch/pytorch#96370
PR on XLA: pytorch/xla#5161

cc @Borda @carmocca @justusschock @awaelchli @JackCaoG @steventk-g @Liyang90

@carmocca carmocca added feature Is an improvement or enhancement fabric lightning.fabric.Fabric precision: amp Automatic Mixed Precision trainer strategy: xla labels Jun 26, 2023
@carmocca carmocca added this to the future milestone Jun 26, 2023
@carmocca carmocca changed the title Support autocast with TPUs Support AMP with TPUs Jun 26, 2023
No branches or pull requests

1 participant