
PyTorch Lightning FSDP takes more memory than PyTorch FSDP #19721

Open
anandhperumal opened this issue Apr 1, 2024 · 6 comments · May be fixed by #20323
Labels
question (Further information is requested), strategy: fsdp (Fully Sharded Data Parallel)

Comments


anandhperumal commented Apr 1, 2024

Bug description

PyTorch Lightning FSDP is taking more memory than PyTorch FSDP.
I'm able to train the gemma-2b model, but it takes 3 times more memory.

For openchat it runs out of memory.
Please let me know if I'm missing anything.

I'm using 8 * 80 GB A100s.

What version are you seeing the problem on?

v2.2

How to reproduce the bug

import torch
import torch.nn as nn

from torch.utils.data import DataLoader
from custom_dataset import SupervisedDataset
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
import transformers

class LanguageModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = None

    def training_step(self, batch):
        outputs = self.model(**batch)
        loss = outputs[0]
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.1)

    def configure_model(self):
        if self.model is not None:
            return
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            'openchat/openchat_3.5',
            torch_dtype=torch.bfloat16,
        )

L.seed_everything(42)
tokenizer = transformers.AutoTokenizer.from_pretrained('openchat/openchat_3.5')
tokenizer.pad_token = tokenizer.eos_token
train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=f'./train.csv', mode='train')
eval_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=f'./validation.csv', mode='validation')
test_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=f'./test.csv', mode='test')

train_dataloader = DataLoader(train_dataset, batch_size=1)

model = LanguageModel()
sharding_strategy = {}

policy = {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
sharding_strategy['sharding_strategy'] = "FULL_SHARD"
sharding_strategy['auto_wrap_policy'] = policy
sharding_strategy['state_dict_type'] = "full"
sharding_strategy['limit_all_gathers'] = True
sharding_strategy['cpu_offload'] = True

strategy = FSDPStrategy(
    **sharding_strategy
)

trainer = L.Trainer(accelerator="cuda",  strategy=strategy, precision=16, max_epochs=1)
trainer.fit(model, train_dataloader)
# trainer.print(torch.cuda.memory_summary())
trainer.save_checkpoint("path/to/checkpoint/file")

For the PyTorch FSDP code: https://github.com/AnswerDotAI/fsdp_qlora/blob/main/train.py
For PyTorch FSDP, I'm using: use_gradient_checkpointing True, use_activation_cpu_offload False, use_cpu_offload False.

The context size is the same for both.

Error messages and logs

CUDA out of memory

Environment

Current environment
- PyTorch Lightning Version: 2.2.1
- PyTorch Version: 2.2.1
- Python version: 3.9
- OS: Linux
- CUDA/cuDNN version: 12.1.10
- GPU models and configuration: A100 (8 * 80GB)
- How you installed Lightning: pip
- Running environment: local

More info

No response

cc @awaelchli @carmocca


awaelchli commented Apr 1, 2024

Hi @anandhperumal

The reference implementation is using LoRA, but I don't see this configured anywhere in your code snippet. This will make a very big difference in memory consumption. Furthermore, you didn't enable activation checkpointing in FSDP in the code above, but you reported doing so in the reference implementation, which will have another big impact. Please check again. It's important to compare equivalent settings.

If you'd like to try out LoRA with Lightning, we have an implementation here (and docs).
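
For illustration, a minimal sketch of how activation checkpointing could be enabled on the Lightning side so the two runs compare like for like (the layer class below is a placeholder and must match the transformer block class of the model actually being trained):

import torch.nn as nn
from lightning.pytorch.strategies import FSDPStrategy

# the block class to shard and to checkpoint; nn.TransformerDecoderLayer is only
# a placeholder and should be replaced by the model's actual decoder block class
policy = {nn.TransformerDecoderLayer}

strategy = FSDPStrategy(
    sharding_strategy="FULL_SHARD",
    auto_wrap_policy=policy,
    # mirrors gradient/activation checkpointing in the reference script
    activation_checkpointing_policy=policy,
)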

@anandhperumal
Author

@awaelchli so the reference code also has a "full" option which trains the entire model. I'm using the full option.

python -u src/train.py --world_size 8 --master_port 12356 --model_name openchat/openchat-3.5 --gradient_accumulation_steps 4 --batch_size 1 --precision bf16 --train_type full --use_gradient_checkpointing false --save_model true  --use_activation_cpu_offload false --use_cpu_offload false --num_epochs 3 --lr_scheduler linear

Also, I updated the Lightning activation checkpointing policy, yet there is no difference.
In fact, I even enabled cpu_offload for PyTorch Lightning and not for PyTorch.

sharding_strategy['activation_checkpointing_policy'] = policy

For your reference, see the image below for openchat memory consumption with the PyTorch FSDP code, whereas Lightning doesn't even run one training step. Please let me know if I'm missing anything.

[screenshot: per-GPU memory consumption for the PyTorch FSDP run]


awaelchli commented Apr 2, 2024

Thanks. This is a very important detail that changes everything. But there are still many differences between the code that you shared and the reference.

I see you are specifying --precision bf16; in Lightning the equivalent would be precision="bf16-true" (not 16 as you show). Also, the reference implementation truncates sequences to --context_length 512 by default, but this is nowhere specified in your code. Can you please share the full code you are running (replacing the data with dummy random data) so that we can reproduce what you are seeing? I also encourage you to compare with our LitGPT implementation that I shared earlier.
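
For reference, a minimal sketch of what the matching setup could look like on the Lightning side, building on the snippet in the issue (the tokenizer truncation shown is an assumption meant to mirror the reference's --context_length 512):

trainer = L.Trainer(
    accelerator="cuda",
    strategy=strategy,
    precision="bf16-true",  # equivalent of --precision bf16; not precision=16
    max_epochs=1,
)

# and truncate sequences to the same context length inside the dataset, e.g.
# tokenizer(text, truncation=True, max_length=512, return_tensors="pt")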


awaelchli commented Apr 2, 2024

Another bug in the code is policy = {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}: you are using openchat/openchat_3.5 from Hugging Face, which doesn't contain these layer classes, so nothing will be wrapped. The reference implementation defines the policy here: https://github.com/AnswerDotAI/fsdp_qlora/blob/3afe102048754c3fc499624511ab1a5ccd6ee45d/train.py#L441-L444. You should use the same.
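
For illustration, a sketch of a wrapping policy that matches the reference implementation (MistralDecoderLayer is an assumption based on openchat_3.5 being a Mistral-style checkpoint; verify against the module classes of the loaded model):

from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from lightning.pytorch.strategies import FSDPStrategy

# wrap (and checkpoint) the Hugging Face decoder blocks rather than torch.nn layers
policy = {MistralDecoderLayer}
strategy = FSDPStrategy(
    sharding_strategy="FULL_SHARD",
    auto_wrap_policy=policy,
    activation_checkpointing_policy=policy,
)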


anandhperumal commented Apr 3, 2024

Thank you so much for getting back so quickly.

Okay, I created a sample script for you.
To run answer_ai.py:

python answer_ai.py --world_size 8 --master_port 12356 --model_name openchat/openchat_3.5 --gradient_accumulation_steps 4 --batch_size 1 --precision bf16 --train_type full --use_gradient_checkpointing false --save_model true --use_activation_cpu_offload false --use_cpu_offload false --num_epochs 1

[screenshot: memory usage when running answer_ai.py]

For PyTorch Lightning: you can run it directly, no need to pass any parameters.

There are two issues:

Without precision="bf16-true":

Recomputed values for the following tensors have different metadata than during the forward pass. saved metadata: {'shape': torch.Size([1, 32, 228, 128]), 'dtype': torch.float32, 'device': device(type='cuda', index=3)} recomputed metadata: {'shape': torch.Size([1, 32, 456, 128]), 'dtype': torch.float32, 'device': device(type='cuda', index=3)}

And with precision="bf16-true":

saved metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cpu')} recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float32, 'device': device(type='cpu')}

script.zip


function2-llx commented May 8, 2024

@awaelchli I think I found the bug. I don't see convert_module defined for FSDPPrecision.


And FSDPPrecision.convert_module will therefore fall back to convert_module of lightning.fabric.plugins.Precision, which simply does nothing:

def convert_module(self, module: Module) -> Module:
    """Convert the module parameters to the precision type this plugin handles.

    This is optional and depends on the precision limitations during optimization.
    """
    return module
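
For illustration, a minimal sketch of the kind of override that would make convert_module actually cast the module for the "-true" precision modes (the class name, dtype mapping, and import path here are assumptions, not the actual fix in the linked PR):

import torch
from torch.nn import Module
from lightning.pytorch.plugins.precision import FSDPPrecision

class PatchedFSDPPrecision(FSDPPrecision):
    # hypothetical subclass: cast parameters for the "-true" modes, while the
    # "-mixed" modes keep fp32 parameters and rely on FSDP's MixedPrecision config
    _TRUE_DTYPES = {"bf16-true": torch.bfloat16, "16-true": torch.float16}

    def convert_module(self, module: Module) -> Module:
        dtype = self._TRUE_DTYPES.get(self.precision)
        if dtype is not None:
            return module.to(dtype=dtype)
        return module

Passing plugins=PatchedFSDPPrecision("bf16-true") to the Trainer should then exercise this path.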

tshu-w linked a pull request (#20323) on Oct 6, 2024 that will close this issue.