PyTorch Lightning FSDP takes more memory than PyTorch FSDP #19721
Comments
The reference implementation is using LoRA, but I don't see this configured anywhere in your code snippet. This will make a very big difference in memory consumption. Furthermore, you didn't enable activation checkpointing in FSDP in the code above, but you reported doing so in the reference implementation, which will have another big impact. Please check again. It's important to compare equivalent settings. If you'd like to try out LoRA with Lightning, we have an implementation here (and docs).
@awaelchli The reference code also has a "full" option which trains the entire model. I'm using the full option.
Also, I updated the Lightning activation checkpointing policy (sharding_strategy['activation_checkpointing_policy'] = policy), yet it made no difference. For reference, see the image below for openchat memory consumption with the PyTorch FSDP code, whereas Lightning doesn't even complete one training step. Please let me know if I'm missing anything.
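For context, here is a minimal sketch of how activation checkpointing can be enabled through Lightning's FSDPStrategy. This is not the reporter's actual script; the decoder-layer class used for wrapping (GemmaDecoderLayer) is an assumption for illustration.

```python
# Minimal sketch (not the reporter's code): enabling per-layer wrapping and
# activation checkpointing in Lightning's FSDPStrategy.
# GemmaDecoderLayer is an assumed transformer-block class for illustration.
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
from transformers.models.gemma.modeling_gemma import GemmaDecoderLayer

strategy = FSDPStrategy(
    auto_wrap_policy={GemmaDecoderLayer},                 # shard each decoder layer separately
    activation_checkpointing_policy={GemmaDecoderLayer},  # recompute these layers' activations in backward
    sharding_strategy="FULL_SHARD",
)

trainer = L.Trainer(
    accelerator="cuda",
    devices=8,
    precision="bf16-true",
    strategy=strategy,
)
```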
Thanks. This is a very important detail that changes everything. But there are still many differences between the code that you shared and the reference. I see you are specifying …
Another bug in the code is …
@awaelchli I think I found the bug. I don't find …
And FSDPPrecision.convert_module will eventually fall back to convert_module of lightning.fabric.plugins.Precision, which simply does nothing:
pytorch-lightning/src/lightning/fabric/plugins/precision/precision.py, lines 48 to 54 at 0c8a193
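For reference, the base implementation at those lines is essentially a no-op; the snippet below is a paraphrase of the linked code, not an exact copy (see the file at the commit above for the precise wording):

```python
# Paraphrase of lightning.fabric.plugins.precision.Precision.convert_module:
# the base class returns the module unchanged, so no parameter dtype
# conversion happens unless a subclass overrides this method.
def convert_module(self, module: Module) -> Module:
    """Convert the module parameters to the precision type this plugin handles.

    This is optional and depends on the precision limitations during optimization.
    """
    return module
```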
Bug description
PyTorch Lightning FSDP is taking more memory than PyTorch FSDP.
I'm able to train the gemma-2b model, but it takes 3 times more memory.
For openchat, it goes out of memory.
Please let me know if I'm missing anything.
I'm using 8 × A100 80 GB GPUs.
What version are you seeing the problem on?
v2.2
How to reproduce the bug
For the PyTorch FSDP code: https://github.com/AnswerDotAI/fsdp_qlora/blob/main/train.py
For PyTorch FSDP, I'm using use_gradient_checkpointing: True, use_activation_cpu_offload: False, use_cpu_offload: False.
The context size is the same for both.
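For comparison, here is a rough sketch of what those settings correspond to in plain PyTorch FSDP (full training, gradient checkpointing on, no CPU offload). This is an illustration, not the actual fsdp_qlora script; it assumes the distributed process group is already initialized and uses GemmaDecoderLayer as the wrapping unit purely for example purposes.

```python
# Rough sketch of the described plain-PyTorch-FSDP setup. Not the actual
# fsdp_qlora code; assumes torch.distributed is already initialized and that
# GemmaDecoderLayer is the transformer block to wrap (illustrative choice).
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)
from transformers import AutoModelForCausalLM
from transformers.models.gemma.modeling_gemma import GemmaDecoderLayer

model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.bfloat16)

wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={GemmaDecoderLayer}
)
model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),
    device_id=torch.cuda.current_device(),
    use_orig_params=True,
)

# Gradient (activation) checkpointing applied to each decoder layer,
# no activation CPU offload and no parameter/optimizer CPU offload:
apply_activation_checkpointing(
    model, check_fn=lambda submodule: isinstance(submodule, GemmaDecoderLayer)
)
```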
Error messages and logs
Environment
Current environment
More info
No response
cc @awaelchli @carmocca