Skip to content

Deepspeed vs DDP  #19246

Open
Open
@jpatel-bdai

Description

@jpatel-bdai

Bug description

It is expected that, on a single GPU, the DDP and DeepSpeed strategies (i.e. deepspeed_stage_1, deepspeed_stage_2 and so on) give exactly the same loss values when the seed is fixed. I have a model that uses torch.nn.Parameter, and its forward pass and gradient updates under these two strategies produce loss values that diverge as training progresses. However, the model code is too big to share. Instead, here is a basic script in which I switch the strategy between deepspeed_stage_1 and ddp and the precision between 32 and 16; I still get different results when changing the strategy. Are there tests that ensure the DeepSpeed implementation matches DDP?

What version are you seeing the problem on?

v2.1

How to reproduce the bug

import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision.datasets import MNIST
from torchvision import transforms
tmpdir = os.getcwd()
from lightning import Trainer, LightningModule, LightningDataModule
from lightning.pytorch.loggers.wandb import WandbLogger

# Root directory for the MNIST files: read from the PATH_DATASETS environment
# variable, falling back to the current directory ('.').
PATH_DATASETS = os.environ.get('PATH_DATASETS', '.')
# Batch size shared by every DataLoader constructed in this script.
BATCH_SIZE = 256
# Instantiated at import time with download=True; `train_ds` itself is not
# referenced again anywhere in the visible script.
train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())

# NOTE(review): `Trainer` was already imported from `lightning` above; this line
# rebinds the name (presumably to the same class - confirm) and adds seed_everything.
from lightning.pytorch import Trainer, seed_everything
# Fix the global seed to 42 before any model is built so runs with different
# strategies start from the same random state.
# NOTE(review): workers=True presumably also seeds dataloader workers - see Lightning docs.
seed_everything(42, workers=True)

class LitMNIST(LightningModule):
    """Three-layer MLP classifier for MNIST used as a strategy-comparison repro.

    The network flattens the image, applies two hidden ``Linear -> ReLU ->
    Dropout(0.1)`` stages and a final ``Linear`` to ``num_classes`` outputs.
    ``forward`` returns log-probabilities (``log_softmax``), which the step
    methods feed to ``nll_loss``.

    Args:
        data_dir: Directory passed to ``MNIST`` for reading/downloading data.
        hidden_size: Width of both hidden linear layers.
        learning_rate: Learning rate handed to ``torch.optim.Adam``.

    NOTE(review): the script also passes a separate ``MyDataModule`` to
    ``trainer.fit``; presumably Lightning then uses the datamodule's hooks
    instead of the data hooks defined on this class - confirm.
    """

    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):

        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))])

        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(), nn.Linear(channels * width * height, hidden_size), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Dropout(0.1), nn.Linear(hidden_size, self.num_classes)
        )

    def forward(self, x):
        """Return per-class log-probabilities (log_softmax over dim 1) for ``x``."""
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        """Return the NLL loss of one ``(inputs, targets)`` training batch."""
        x, y = batch
        # forward() already applies log_softmax, so nll_loss is the matching criterion.
        log_probs = self(x)
        return F.nll_loss(log_probs, y)

    def validation_step(self, batch, batch_idx):
        """Compute the NLL loss of one batch, log it as ``val_loss`` and return it."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch by delegating to ``validation_step``.

        Because of the delegation the metric is logged under ``val_loss``.
        """
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        """Instantiate both MNIST splits with ``download=True``."""
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ('fit', 'test', or None for both)."""

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return a DataLoader over the training split (no shuffle argument is passed)."""
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        """Return a DataLoader over the validation split."""
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        """Return a DataLoader over the test split."""
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)


class MyDataModule(LightningDataModule):
    """Data module serving MNIST train/val/test splits in batches of BATCH_SIZE.

    Args:
        data_dir: Directory passed to ``MNIST`` for reading/downloading data.
    """

    def __init__(self, data_dir=PATH_DATASETS):
        super().__init__()
        self.data_dir = data_dir
        # ToTensor followed by Normalize with the constants below.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.1307, ), (0.3081, ))
        self.transform = transforms.Compose([to_tensor, normalize])

    def prepare_data(self):
        """Instantiate the train split, then the test split, with download=True."""
        for is_train in (True, False):
            MNIST(self.data_dir, train=is_train, download=True)

    def setup(self, stage=None):
        """Build the datasets for ``stage``; ``None`` builds fit and test datasets."""
        build_everything = stage is None

        # Train/val: split the 60000-sample training set into 55000 / 5000.
        if build_everything or stage == 'fit':
            full_train = MNIST(self.data_dir, train=True, transform=self.transform)
            train_part, val_part = random_split(full_train, [55000, 5000])
            self.mnist_train = train_part
            self.mnist_val = val_part

        # Test: the held-out MNIST split with the same transform.
        if build_everything or stage == 'test':
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def _batched_loader(self, dataset):
        """Wrap ``dataset`` in a DataLoader using the module-level BATCH_SIZE."""
        return DataLoader(dataset, batch_size=BATCH_SIZE)

    def train_dataloader(self):
        """Return the DataLoader over the training split."""
        return self._batched_loader(self.mnist_train)

    def val_dataloader(self):
        """Return the DataLoader over the validation split."""
        return self._batched_loader(self.mnist_val)

    def test_dataloader(self):
        """Return the DataLoader over the test split."""
        return self._batched_loader(self.mnist_test)


if __name__ == '__main__':
    # Reproduction entry point: train LitMNIST for 10 epochs on one CUDA device
    # with the strategy named below (swap in "ddp" to compare).
    import time

    # The timestamp is only consumed by the optional W&B run id below.
    timestr = time.strftime("%Y%m%d-%H%M%S")
    strategy_name = "deepspeed_stage_1"
    # wandb_logger = WandbLogger(project="test_model",id=f"test_32_{strategy_name}_{timestr}", log_model="all")

    trainer_options = {
        "devices": 1,
        "accelerator": "cuda",
        "max_epochs": 10,
        "precision": 32,
        "strategy": strategy_name,
        # "logger": wandb_logger,
    }

    # Build the model and data module before the Trainer, as in the original script.
    model = LitMNIST()
    datamodule = MyDataModule()
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule)


### Error messages and logs

Error messages and logs here please

![image](https://github.com/Lightning-AI/pytorch-lightning/assets/135647158/984fafba-1e28-491d-a74a-3c8588afcbbf)
These are runs made at different times with the two strategies (deepspeed_stage_1 and ddp) and different precision values. Naming convention: float 32 precision - `test_precision_strategy_timestamp`; float 16 precision - `test_strategy_timestamp`.

### Environment

#- Lightning Component (e.g. Trainer, LightningModule):
#- PyTorch Lightning Version : 2.1.0
#- PyTorch Version: 2.1.0+cu121
#- Python version : Python 3.10.12
#- OS (e.g., Linux): Debian
#- CUDA/cuDNN version: NVIDIA-SMI 535.54.03 Driver Version: 535.54.03 CUDA Version: 12.2
#- GPU models and configuration: NVIDIA L4 (24GB GPU VRAM)
#- How you installed Lightning(conda, pip, source): pip install lightning


cc @awaelchli

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions