Integrating PyTorch XLA when using multiple GPUs #16130
Comments
Let's do it! To be clear, this would be enabled with:
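For context, this is the API proposed in this thread; it is hypothetical until the feature lands, since at the time of writing `strategy="xla"` only works together with TPU accelerators:

```python
# Proposed usage from this issue; hypothetical until the feature is implemented.
from pytorch_lightning import Trainer

trainer = Trainer(accelerator="gpu", devices=2, strategy="xla")
```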
@carmocca I assume we could reuse a lot of our current XLA strategy for TPUs.
That would be part of the goal.
I like it, and I think it won't even be that hard! The abstractions of strategy and accelerator are already in place and are meant to support exactly this kind of relationship between a communication layer (XLA) and an accelerator (GPU/TPU).
This is great! 🐰
Hello, this is wonderful work! I'd like to know when Trainer(accelerator='cuda'|'gpu', strategy='xla') will work normally.
@qipengh We haven't started working on it. The feature is up for grabs if you or anyone from the community has interest in contributing and testing it out.
This should become very easy once we add support for XLA's PJRT runtime: https://github.com/pytorch/xla/blob/master/docs/pjrt.md#gpu
In addition, we need to land …
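For reference, here is a minimal sketch of what selecting the PJRT runtime on a CUDA machine looks like, based on the pjrt.md document linked above. The environment variable names (`PJRT_DEVICE`, `GPU_NUM_DEVICES`) come from that doc and may differ in newer torch_xla releases:

```python
# Sketch only: route torch_xla through the PJRT runtime on local CUDA GPUs.
# Variable names follow the linked pjrt.md (newer releases use PJRT_DEVICE=CUDA).
import os

os.environ["PJRT_DEVICE"] = "GPU"    # select the PJRT GPU client
os.environ["GPU_NUM_DEVICES"] = "2"  # number of local CUDA devices to expose

import torch_xla.core.xla_model as xm

device = xm.xla_device()  # an XLA device backed by a CUDA GPU
print(device)             # e.g. xla:0
```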
Is there an example model showing how to use XLA with a (single) CUDA GPU? The link above has 404'd since it was posted, and I am struggling to find one anywhere; everything I come across is for TPUs only. Roughly how much work do folks think is still needed to implement this feature request?
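Not an official example, but a minimal single-GPU sketch with raw torch_xla (the Lightning integration is what this issue tracks). It assumes a CUDA build of torch_xla is installed and `PJRT_DEVICE=GPU` is set in the environment:

```python
# Minimal single-device training step on an XLA device backed by a CUDA GPU.
# Assumes: CUDA build of torch_xla installed, PJRT_DEVICE=GPU in the env.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()                    # the XLA-wrapped GPU
model = nn.Linear(128, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(32, 128).to(device)
y = torch.randint(0, 10, (32,)).to(device)

optimizer.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()
xm.optimizer_step(optimizer, barrier=True)  # step + materialize the lazy graph
print(loss.item())
```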
Description & Motivation
I've experimented with PyTorch XLA using multiple NVIDIA A100 GPUs and observed that in most cases training is faster. So it would be really nice to have the option to use XLA for training in PyTorch Lightning.
The main advantage is faster training.
Additional context
Here is a code link to train a simple CNN on the MNIST dataset using XLA (on 2 GPUs): https://github.com/Dhouib-med/Test-XLA/blob/17e5b6bd6c77fffa67818462856277a57877ff3b/test_xla.py The main parts were taken from https://github.com/pytorch/xla; a rough sketch of the same pattern follows below.
This wheel needs to be installed along with compatible PyTorch and torchvision versions (1.13 and 0.14): https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.13-cp37-cp37m-linux_x86_64.whl
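For readers without access to the linked script, here is a hedged sketch of the standard torch_xla one-process-per-device pattern it follows. The identifiers come from the pytorch/xla examples, not from the linked file itself:

```python
# Sketch of the multi-process torch_xla pattern used by the pytorch/xla
# examples; with the CUDA wheel above, GPU_NUM_DEVICES=2 exposes both GPUs.
# Not the exact contents of the linked script.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
from torchvision import datasets, transforms

def _mp_fn(index):
    device = xm.xla_device()  # this process's GPU, wrapped as an XLA device
    model = nn.Sequential(
        nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 26 * 26, 10)
    ).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()

    train_ds = datasets.MNIST(
        "/tmp/mnist", train=True, download=True, transform=transforms.ToTensor()
    )
    loader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
    # MpDeviceLoader moves batches onto the XLA device in the background.
    device_loader = pl.MpDeviceLoader(loader, device)

    for x, y in device_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        xm.optimizer_step(optimizer)  # all-reduces gradients, then steps

if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=())  # one process per visible device
```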
@justusschock
cc @Borda @justusschock @awaelchli @carmocca @JackCaoG @steventk-g @Liyang90