
Commit dee63cc

Merge pull request #6096 from BurkeHulk/hotfix/lora_ckpt
[hotfix] fix lora ckpt saving format
2 parents 19baab5 + 6d6cafa

File tree

3 files changed: +18 -3 lines changed

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 5 additions & 1 deletion
@@ -290,7 +290,11 @@ def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
 
 
 class LowLevelZeroPlugin(DPPluginBase):
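
All three files apply the same fix: instead of handing peft_model.state_dict() to save_pretrained as-is, every tensor leaf is first mapped to its underlying .data, so plain torch.Tensor objects rather than parameter or wrapper subclasses reach the serializer. A minimal stand-alone sketch of that tree_map pattern (the nn.Linear stand-in and keep_vars=True are illustrative assumptions, not part of the diff):

# Sketch of the tree_map pattern used in the change above.
# nn.Linear stands in for the PEFT model; keep_vars=True only makes the effect
# of `.data` visible on a plain module (the plugins call state_dict() directly).
import torch
import torch.nn as nn
from torch.utils._pytree import tree_map

model = nn.Linear(4, 2)
raw = model.state_dict(keep_vars=True)  # values are nn.Parameter with requires_grad=True
plain = tree_map(lambda x: x.data if torch.is_tensor(x) else x, raw)

for name, tensor in plain.items():
    # every leaf is now a plain torch.Tensor detached from autograd
    print(name, type(tensor).__name__, tensor.requires_grad)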

colossalai/booster/plugin/torch_ddp_plugin.py

Lines changed: 7 additions & 1 deletion
@@ -1,9 +1,11 @@
 from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
 
+import torch
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
@@ -134,7 +136,11 @@ def save_lora_as_pretrained(
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
 
 
 class TorchDDPModel(ModelWrapper):

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

Lines changed: 6 additions & 1 deletion
@@ -11,6 +11,7 @@
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
 
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -956,4 +957,8 @@ def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
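
For reference, the same save_pretrained call shape can be exercised outside ColossalAI on a plain PEFT model; the tiny module, its target_modules entry, and the output directory below are hypothetical, and only the final call mirrors the diff:

# Hedged sketch: reproduce the fixed call on a small stand-alone PEFT model.
import os
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from torch.utils._pytree import tree_map

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 4)

    def forward(self, x):
        return self.proj(x)

peft_model = get_peft_model(TinyNet(), LoraConfig(target_modules=["proj"], r=4))

checkpoint = "lora_ckpt"  # hypothetical output directory
peft_model.save_pretrained(
    checkpoint,
    safe_serialization=True,  # corresponds to use_safetensors=True in the plugins
    state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
)
# the directory should contain adapter_config.json and adapter_model.safetensors
print(sorted(os.listdir(checkpoint)))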
