Unable to change checkpoint in on_save_checkpoint with DeepSpeed #18747

Open
xluo233 opened this issue Oct 8, 2023 · 5 comments
Labels: bug, checkpointing, strategy: deepspeed, ver: 2.0.x

xluo233 commented Oct 8, 2023

Bug description

When using DeepSpeed, changes made to the checkpoint (adding or removing keys) in on_save_checkpoint are not preserved. When switching the strategy to ddp, the changes are saved as expected.

Environment

  • Lightning: 2.0.3
  • DeepSpeed: 0.10.1

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from lightning.pytorch.utilities.deepspeed import ds_checkpoint_dir

from deepspeed.utils.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    get_model_state_file,
    get_optim_files,
)


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self, remove_key=None):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.remove_key = remove_key

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def on_save_checkpoint(self, checkpoint) -> None:
        if self.remove_key is not None:
            print(f"{checkpoint['state_dict'].keys()} in state_dict ")
            self.remove_params(checkpoint, key=self.remove_key)
            print(f"{checkpoint['state_dict'].keys()} in state_dict after remove params\n")


    def remove_params(self, checkpoint, key) -> None:
        del_keys = []

        for k in checkpoint["state_dict"]:
            if key in k:
                del_keys.append(k)
        print(f"{len(del_keys)} keys to remove")

        for k in del_keys:
            checkpoint["state_dict"].pop(k)



def run(key, strategy):
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel(key)

    trainer = Trainer(
        strategy=strategy,
        accelerator='gpu',
        devices=1,
        callbacks=[ModelCheckpoint(save_top_k=1, save_last=True)],
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=10,
        enable_model_summary=False,
    )

    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)



def load_ckpt(checkpoint_dir, tag=None, map_location="cpu"):
    # Rebuild a consolidated, Lightning-style checkpoint dict from a DeepSpeed
    # ZeRO checkpoint directory: fp32 weights reconstructed from the ZeRO shards
    # plus the remaining client state saved alongside them.
    CPU_DEVICE = map_location
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    # additional logic to ensure we keep the lightning state dict as well from rank 0.
    deepspeed_states = [
        "module",
        "optimizer",
        "lr_scheduler",
        "csr_tensor_module_names",
        "skipped_steps",
        "global_steps",
        "dp_world_size",
        "mp_world_size",
    ]
    checkpoint_dir = ds_checkpoint_dir(checkpoint_dir)
    optim_files = get_optim_files(checkpoint_dir)
    optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE)
    zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
    model_file = get_model_state_file(checkpoint_dir, zero_stage)
    client_state = torch.load(model_file, map_location=CPU_DEVICE)
    client_state = {
        key: value
        for key, value in client_state.items()
        if key not in deepspeed_states
    }

    state_dict = {k.partition("module.")[2]: state_dict[k] for k in state_dict}
    client_state["state_dict"] = state_dict

    return client_state


if __name__ == "__main__":
    run(None, 'deepspeed')
    run("layer.weight", 'deepspeed')
    run(None, 'ddp')
    run("layer.weight", 'ddp')

    ckpt_v0 = load_ckpt('./lightning_logs/version_0/checkpoints/last.ckpt')
    print(f"keys in ckpt_v0 {ckpt_v0['state_dict'].keys()} \n")
    
    ckpt_v1 = load_ckpt('./lightning_logs/version_1/checkpoints/last.ckpt')
    print(f"keys in ckpt_v1 {ckpt_v1['state_dict'].keys()} \n")

    ckpt_v2 = torch.load('./lightning_logs/version_2/checkpoints/last.ckpt')
    print(f"keys in ckpt_v2 {ckpt_v2['state_dict'].keys()} \n")

    ckpt_v3 = torch.load('./lightning_logs/version_3/checkpoints/last.ckpt')
    print(f"keys in ckpt_v3 {ckpt_v3['state_dict'].keys()} \n")

Error messages and logs

Processing zero checkpoint './lightning_logs/version_0/checkpoints/last.ckpt/checkpoint'
Detected checkpoint of type zero stage 2, world_size: 1
Parsing checkpoint created by deepspeed==0.10.1
Reconstructed fp32 state dict with 2 params 66 elements
keys in ckpt_v0 dict_keys(['layer.weight', 'layer.bias']) 

Processing zero checkpoint './lightning_logs/version_1/checkpoints/last.ckpt/checkpoint'
Detected checkpoint of type zero stage 2, world_size: 1
Parsing checkpoint created by deepspeed==0.10.1
Reconstructed fp32 state dict with 2 params 66 elements
keys in ckpt_v1 dict_keys(['layer.weight', 'layer.bias']) 

keys in ckpt_v2 odict_keys(['layer.weight', 'layer.bias']) 

keys in ckpt_v3 odict_keys(['layer.bias'])


cc @awaelchli

xluo233 added the bug and needs triage labels on Oct 8, 2023

cyr0930 commented Nov 6, 2023

I have the same issue!

carmocca added the checkpointing and strategy: deepspeed labels and removed the needs triage label on Nov 21, 2023
carmocca added this to the 2.2 milestone on Nov 21, 2023

carmocca commented Nov 21, 2023

Unfortunately, deepspeed 0.10 is not supported yet. Our testing is currently pinned to https://github.com/Lightning-AI/lightning/blob/master/requirements/pytorch/strategies.txt#L6

carmocca modified the milestones: 2.2, future on Nov 21, 2023

xluo233 commented Dec 2, 2023

Unfortunately, deepspeed 0.10 is not supported yet. Our testing is currently pinned to https://github.com/Lightning-AI/lightning/blob/master/requirements/pytorch/strategies.txt#L6

I set up the environment with the compatible version:

conda create -n deepspeed_env python=3.8
conda activate deepspeed_env

conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia

pip install deepspeed==0.9.3

And modified the code accordingly:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from lightning.pytorch.utilities.deepspeed import ds_checkpoint_dir

from deepspeed.utils.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    get_model_state_file,
    get_optim_files,
)


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self, remove_key=None):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.remove_key = remove_key

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def on_save_checkpoint(self, checkpoint) -> None:
        if self.remove_key is not None:
            print(f"{checkpoint['state_dict'].keys()} in state_dict ")
            self.remove_params(checkpoint, key=self.remove_key)
            print(f"{checkpoint['state_dict'].keys()} in state_dict after remove params\n")


    def remove_params(self, checkpoint, key) -> None:
        del_keys = []

        for k in checkpoint["state_dict"]:
            if key in k:
                del_keys.append(k)
        print(f"{len(del_keys)} keys to remove")

        for k in del_keys:
            checkpoint["state_dict"].pop(k)



def run(key, strategy):
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel(key)

    trainer = Trainer(
        strategy=strategy,
        accelerator='gpu',
        devices=1,
        callbacks=[ModelCheckpoint(save_top_k=1, save_last=True)],
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=10,
        enable_model_summary=False,
    )

    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)



def load_ckpt(checkpoint_dir, tag=None, map_location="cpu"):
    CPU_DEVICE = map_location
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    # additional logic to ensure we keep the lightning state dict as well from rank 0.
    deepspeed_states = [
        "module",
        "optimizer",
        "lr_scheduler",
        "csr_tensor_module_names",
        "skipped_steps",
        "global_steps",
        "dp_world_size",
        "mp_world_size",
    ]
    checkpoint_dir = ds_checkpoint_dir(checkpoint_dir)
    optim_files = get_optim_files(checkpoint_dir)
    optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE)
    zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
    model_file = get_model_state_file(checkpoint_dir, zero_stage)
    client_state = torch.load(model_file, map_location=CPU_DEVICE)
    client_state = {
        key: value
        for key, value in client_state.items()
        if key not in deepspeed_states
    }

    #state_dict = {k.partition("module.")[2]: state_dict[k] for k in state_dict}
    client_state["state_dict"] = state_dict

    return client_state


if __name__ == "__main__":
    run(None, 'deepspeed')
    run("layer.weight", 'deepspeed')
    run(None, 'ddp')
    run("layer.weight", 'ddp')

    ckpt_v0 = load_ckpt('./lightning_logs/version_0/checkpoints/last.ckpt')
    ckpt_v1 = load_ckpt('./lightning_logs/version_1/checkpoints/last.ckpt')
    ckpt_v2 = torch.load('./lightning_logs/version_2/checkpoints/last.ckpt')
    ckpt_v3 = torch.load('./lightning_logs/version_3/checkpoints/last.ckpt')

    print(f"ckpt_v0 deepspeed ckpt no parameter removed, \nexpected key layer.weight, layer.bias \ngot  {list(ckpt_v0['state_dict'].keys())} \n\n")
    print(f"ckpt_v1 deepspeed ckpt layer.weight removed, \nexpected key layer.bias \ngot {list(ckpt_v1['state_dict'].keys())} \n\n")
    print(f"ckpt_v2 ddp ckpt no parameter removed, \nexpected key layer.weight, layer.bias \ngot  {list(ckpt_v2['state_dict'].keys())} \n\n")
    print(f"ckpt_v3 ddp ckpt layer.weight removed, \nexpected key layer.bias \ngot {list(ckpt_v3['state_dict'].keys())} \n\n")

The output of the script:

ckpt_v0 deepspeed ckpt no parameter removed, 
expected key layer.weight, layer.bias 
got  ['layer.weight', 'layer.bias'] 


ckpt_v1 deepspeed ckpt layer.weight removed, 
expected key layer.bias 
got ['layer.weight', 'layer.bias'] 


ckpt_v2 ddp ckpt no parameter removed, 
expected key layer.weight, layer.bias 
got  ['layer.weight', 'layer.bias'] 


ckpt_v3 ddp ckpt layer.weight removed, 
expected key layer.bias 
got ['layer.bias'] 

The parameter change is not saved when using DeepSpeed.
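
A possible post-hoc workaround (only a sketch, not from the discussion above): since the fp32 weights are reconstructed from the ZeRO shards that DeepSpeed writes itself, edits made to checkpoint['state_dict'] in on_save_checkpoint appear to have no effect on them. The unwanted keys can instead be dropped after conversion. strip_and_save below is a hypothetical helper built on deepspeed's get_fp32_state_dict_from_zero_checkpoint; the path and key come from the repro above.

import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint


def strip_and_save(checkpoint_dir, remove_key, out_path, tag=None):
    # Rebuild the consolidated fp32 state dict from the ZeRO checkpoint
    # directory, drop every key containing `remove_key`, and save a plain
    # .pt file that can be read back with torch.load.
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
    filtered = {k: v for k, v in state_dict.items() if remove_key not in k}
    torch.save({"state_dict": filtered}, out_path)
    return filtered


# strip_and_save('./lightning_logs/version_1/checkpoints/last.ckpt',
#                'layer.weight', 'last_stripped.pt')

Note that this only filters the converted state dict after the fact; the DeepSpeed checkpoint directory on disk is left unchanged.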

@ForJadeForest

Is there any solution? Saving a frozen LLM in the checkpoint is too large and slow.


cyr0930 commented Oct 28, 2024

any improvement?
