
trainer.save_checkpoint doesn't work after trainer.test with deepspeed strategy #15247

Open

rohitgr7 (Contributor) opened this issue Oct 22, 2022

Bug description

Reported here: #14944 (comment)

Reason? Read the thread: #14944 (comment)

In short, the following sequence

trainer.fit()
trainer.test()
trainer.save_checkpoint()

does not work: after trainer.test(), the DeepSpeed engine no longer has an optimizer attached, so the subsequent checkpoint save fails.

Either we need to update the strategy to handle this case, or support needs to be improved in the DeepSpeed package itself to allow saving a checkpoint without any optimizer.
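A simple interim workaround (just a sketch, not a fix; it assumes you only need the weights as they are right after fitting) is to save the checkpoint before calling trainer.test(), while the DeepSpeed training engine and its optimizer are still in place:

trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
# Save while the training engine (with optimizer state) is still active.
trainer.save_checkpoint("fit_checkpoint.ckpt")
trainer.test(model, dataloaders=test_data)
# trainer.save_checkpoint("fit_test_checkpoint.ckpt")  # this is the call that currently fails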

Full repro:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.strategies import DeepSpeedStrategy


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()

    def configure_sharded_model(self):
        # Layers are instantiated here so they are sharded as they are created (ZeRO stage 3).
        self.layer = torch.nn.Sequential(
            torch.nn.Linear(32, 10000),
            torch.nn.Linear(10000, 1000),
            torch.nn.Linear(1000, 2),
        )

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        num_sanity_val_steps=0,
        max_epochs=2,
        enable_model_summary=False,
        devices=2,
        accelerator="cuda",
        strategy="deepspeed_stage_3",
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)
    trainer.save_checkpoint("fit_test_checkpoint.ckpt")  # fails after trainer.test()


if __name__ == "__main__":
    run()

Environment


#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 1.10):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

Issue on DeepSpeed GitHub: microsoft/DeepSpeed#3601
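For context, one weights-only escape hatch that may work today is to consolidate and save just the module weights instead of a full DeepSpeed checkpoint. This is a sketch under two assumptions: that the strategy exposes the engine as trainer.strategy.deepspeed_engine, and that DeepSpeed's save_16bit_model can consolidate the stage-3 shards from the engine that is active after trainer.test().

# Hypothetical weights-only save, bypassing trainer.save_checkpoint().
# Saves only the consolidated model weights; optimizer and loop state are not included.
engine = trainer.strategy.deepspeed_engine
engine.save_16bit_model("fit_test_checkpoint_dir", "pytorch_model.bin")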

Labels: checkpointing, strategy: deepspeed, trainer: validate
Milestone: future