
Deepspeed vs DDP #19246

Open · jpatel-bdai opened this issue Jan 8, 2024 · 2 comments

jpatel-bdai commented Jan 8, 2024

Bug description

It is expected that, on a single GPU with a fixed seed, the DDP and DeepSpeed strategies (`deepspeed_stage_1`, `deepspeed_stage_2`, and so on) should produce exactly the same loss values. I have a model that uses `torch.nn.Parameter`, and with these two strategies the forward pass and gradient updates give diverging loss values as training progresses. The model code is too big to share, but even with the basic script below I get different results when switching between `deepspeed_stage_1` and `ddp` at different precision values (32 and 16). Are there tests carried out to ensure the DeepSpeed implementation matches DDP?

What version are you seeing the problem on?

v2.1

How to reproduce the bug

```python
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms

from lightning import Trainer, LightningModule, LightningDataModule
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers.wandb import WandbLogger

PATH_DATASETS = os.environ.get('PATH_DATASETS', '.')
BATCH_SIZE = 256

seed_everything(42, workers=True)

class LitMNIST(LightningModule):

    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):

        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))])

        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(), nn.Linear(channels * width * height, hidden_size), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Dropout(0.1), nn.Linear(hidden_size, self.num_classes)
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)


class MyDataModule(LightningDataModule):
    def __init__(self, data_dir=PATH_DATASETS):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))])

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)


if __name__ == '__main__':
    import time
    timestr = time.strftime("%Y%m%d-%H%M%S")
    strategy_name = "deepspeed_stage_1"  # switch between "deepspeed_stage_1" and "ddp" to compare
    # wandb_logger = WandbLogger(project="test_model", id=f"test_32_{strategy_name}_{timestr}", log_model="all")

    model = LitMNIST()
    datamodule = MyDataModule()
    trainer = Trainer(
        devices=1,
        accelerator="cuda",
        max_epochs=10,
        precision=32,  # also tested with 16
        strategy=strategy_name,
        # logger=wandb_logger,
    )
    trainer.fit(model, datamodule=datamodule)
```
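To make the strategy comparison concrete, one could dump the per-step training losses from each run and diff the two series. A minimal sketch (editorial addition; the `LossRecorder` callback and the file naming are hypothetical, not part of the original report):

```python
# Hypothetical helper: record per-step training losses to a file for comparison.
import json

from lightning.pytorch.callbacks import Callback


class LossRecorder(Callback):
    def __init__(self, out_path: str):
        self.out_path = out_path
        self.losses = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # `outputs` is what the training loop hands back from training_step;
        # handle both the dict form ({"loss": ...}) and a bare tensor.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs
        self.losses.append(float(loss))

    def on_train_end(self, trainer, pl_module):
        with open(self.out_path, "w") as f:
            json.dump(self.losses, f)


# Usage: run once per strategy, then compare the two JSON files element-wise.
# trainer = Trainer(..., callbacks=[LossRecorder(f"losses_{strategy_name}.json")])
```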


### Error messages and logs


![image](https://github.com/Lightning-AI/pytorch-lightning/assets/135647158/984fafba-1e28-491d-a74a-3c8588afcbbf)
These are runs at different times with the two strategies (`deepspeed_stage_1` and `ddp`) and different precision values. Nomenclature: float32 runs are named `test_{precision}_{strategy}_{timestamp}`, float16 runs `test_{strategy}_{timestamp}`.

### Environment

- Lightning Component (e.g. Trainer, LightningModule):
- PyTorch Lightning Version: 2.1.0
- PyTorch Version: 2.1.0+cu121
- Python version: 3.10.12
- OS (e.g., Linux): Debian
- CUDA/cuDNN version: NVIDIA-SMI 535.54.03, Driver Version 535.54.03, CUDA Version 12.2
- GPU models and configuration: NVIDIA L4 (24 GB VRAM)
- How you installed Lightning (conda, pip, source): pip install lightning


cc @awaelchli
jpatel-bdai added the `bug` and `needs triage` labels on Jan 8, 2024
awaelchli (Contributor) commented

Hey @jpatel-bdai

> Are there tests carried out to ensure the DeepSpeed implementation matches DDP?

There are not. The DeepSpeed integration in Lightning has been challenging to maintain, and the DeepSpeed maintainers themselves don't write any tests for their software. It is unclear what the future of DeepSpeed in Lightning is.

In any case, one thing to check in your experiment is whether the modules are initialized identically (same random weights) in both runs, and to set `Trainer(deterministic=True)` as a sanity check.
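A minimal sketch of that first check (editorial addition; `param_fingerprint` is a hypothetical helper, not a Lightning API): seed, instantiate the model, and hash its parameters, then compare the digest printed before the DDP run with the one printed before the DeepSpeed run.

```python
# Hypothetical helper: fingerprint the initial weights so two runs can be compared.
import hashlib

import torch
from lightning.pytorch import seed_everything


def param_fingerprint(module: torch.nn.Module) -> str:
    """Hash all parameters into a short hex digest (single-process check)."""
    flat = torch.cat([p.detach().float().cpu().reshape(-1) for p in module.parameters()])
    return hashlib.sha256(flat.numpy().tobytes()).hexdigest()[:16]


seed_everything(42, workers=True)
model = LitMNIST()  # the module from the reproduction script above
print("initial weight fingerprint:", param_fingerprint(model))
# Identical digests across the DDP and DeepSpeed runs mean the divergence
# happens during training, not at initialization.
```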

awaelchli added the `strategy: deepspeed` label and removed the `needs triage` label on Jan 9, 2024
jpatel-bdai (Author) commented

I verified that the modules are initialized with the same weights and set `deterministic=True` in `Trainer()`, but the DDP and DeepSpeed loss values still do not match on a single GPU. The issue I am currently facing is as follows. My model initialization contains

```python
tensor_x = torch.nn.Parameter(torch.zeros((dim_a, dim_b)))
```

and the forward pass does

```python
tensor_x.data = torch.nn.functional.normalize(tensor_x.data, dim=-1)
```

During the forward pass a few of the `tensor_x.data` values differ between the two strategies at the third decimal place (e.g. 12.04345 vs 12.04556), while the majority of values are exactly the same. But this impacts the performance of the model: as training progresses, the DeepSpeed losses do not go as low as the DDP losses. This is with `deepspeed_stage_1`. Do you have any potential directions I could look into?
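One direction worth probing (an editorial sketch, not a suggestion from the thread): assigning to `tensor_x.data` rewrites the parameter's storage outside of autograd, and ZeRO-style strategies keep their own view of the parameters (e.g. fp32 master weights under fp16 training), so out-of-band mutations like this are a plausible place for DDP and DeepSpeed to drift apart. A normalization that stays inside autograd never mutates the stored parameter; `dim_a`/`dim_b` below are placeholders:

```python
# Hypothetical minimal module contrasting the two normalization patterns.
import torch
from torch import nn
import torch.nn.functional as F


class NormalizedParam(nn.Module):
    def __init__(self, dim_a: int = 8, dim_b: int = 4):
        super().__init__()
        self.tensor_x = nn.Parameter(torch.randn(dim_a, dim_b))

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # Original pattern: overwrite the parameter storage outside autograd.
        # self.tensor_x.data = F.normalize(self.tensor_x.data, dim=-1)
        # return h @ self.tensor_x.t()

        # Autograd-friendly alternative: normalize on the fly. The stored
        # parameter is never mutated, gradients flow through the
        # normalization, and the optimizer's view of the weights stays
        # consistent regardless of strategy.
        x = F.normalize(self.tensor_x, dim=-1)
        return h @ x.t()
```

If the parameter really must stay normalized in storage, `torch.nn.utils.parametrize.register_parametrization` is a built-in way to express that without touching `.data` by hand.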

> The DeepSpeed integration in Lightning has been challenging to maintain, and the DeepSpeed maintainers themselves don't write any tests for their software. It is unclear what the future of DeepSpeed in Lightning is.

In that case, what path forward do you suggest? We ported our codebase to Lightning because it makes using the DeepSpeed and FSDP strategies easier.
