+import enum
 import logging
 import os
+import warnings
 from functools import partial
 from pathlib import Path
 from types import MethodType
 from typing import Callable, Dict, Iterator, List, Optional, Tuple

 import torch
 import torch.nn as nn
+from torch.nn import Parameter
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
@@ -42,6 +45,12 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]


+class OptimizerParamCheckState(enum.Enum):
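+    """Result of locating a LoRA parameter's base ("origin") parameter in the optimizer's param groups."""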
+    ORIGIN_PARAM_FINDED = 0
+    ORIGIN_PARAM_NOT_FIND = -1
+    LORA_PARM_EXISTED = -2
+
+
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
     def __init__(self, module: nn.Module, precision: str) -> None:
         super().__init__(module)
@@ -209,6 +218,19 @@ def load_sharded_model(
         super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
         model.update_master_params()

+    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
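+        """Save the LoRA adapter weights of a boosted model to the ``checkpoint`` directory
+        via peft's ``save_pretrained``."""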
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+        from peft import PeftModel
+
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        peft_model = model.unwrap()
+        assert isinstance(
+            peft_model, PeftModel
+        ), "The model doesn't have LoRA adapters; please enable LoRA before saving."
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+

 class LowLevelZeroPlugin(DPPluginBase):
     """
@@ -288,6 +310,7 @@ def __init__(
             cpu_offload=cpu_offload,
             master_weights=master_weights,
         )
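+        # Flipped to True by enable_lora() once the model has been wrapped with LoRA adapters.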
+        self.lora_enabled = False
         self.verbose = verbose

         # set class name with stage, for better error message
@@ -311,6 +334,72 @@ def control_device(self) -> bool:
     def supported_devices(self) -> List[str]:
         return ["cuda", "npu"]

+    def support_lora(self) -> bool:
+        return True
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
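+        """Wrap ``model`` as a ``peft.PeftModel``: build adapters from ``lora_config``,
+        or load trainable adapters from ``pretrained_dir``."""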
+        from peft import PeftModel, get_peft_model
+
+        assert not isinstance(model, LowLevelZeroModel), "LoRA should be enabled before boosting the model."
+        self.lora_enabled = True
+        warnings.warn("You have enabled LoRA training. Please check hyperparameters such as the learning rate.")
+
+        if pretrained_dir is None:
+            peft_model = get_peft_model(model, lora_config)
+        else:
+            peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
+        return peft_model
+
+    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
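+        """Return the index of the param group that holds ``origin_param``, or -1 if absent."""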
+        origin_param_id = id(origin_param)
+        for group_id, param_group in enumerate(optimizer.param_groups):
+            for p in param_group["params"]:
+                if id(p) == origin_param_id:
+                    return group_id
+        return -1
+
+    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
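+        # NOTE: this three-argument definition shadows the two-argument one above; it returns
+        # both the group index holding ``origin_param`` and a check state for ``lora_param``.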
+        origin_param_id = id(origin_param)
+        lora_param_id = id(lora_param)
+        target_group_id = None
+        for group_id, param_group in enumerate(optimizer.param_groups):
+            for p in param_group["params"]:
+                if id(p) == lora_param_id:
+                    # the LoRA parameter is already registered in the optimizer.
+                    return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
+                if id(p) == origin_param_id:
+                    target_group_id = group_id
+        if target_group_id is not None:
+            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED
+        else:
+            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND
+
+    def add_lora_params_to_optimizer(self, model, optimizer):
+        """Add the LoRA adapter parameters to the optimizer's param groups."""
+        name2param = {}
+        for name, param in model.named_parameters():
+            name2param[name] = param
+
+        for name, param in name2param.items():
+            if "lora_A" in name or "lora_B" in name:
+                origin_key = name.replace("lora_A.", "")
+                origin_key = origin_key.replace("lora_B.", "")
+                origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
+                origin_param = name2param[origin_key]
+                group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
+                if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
+                    warnings.warn(
+                        f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
+                    )
+                elif (
+                    check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
+                    and group_id is not None
+                    and group_id >= 0
+                ):
+                    optimizer.param_groups[group_id]["params"].append(param)
+
     def configure(
         self,
         model: nn.Module,
@@ -319,6 +408,15 @@ def configure(
         dataloader: Optional[DataLoader] = None,
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
+        if self.lora_enabled:
+            from peft import PeftModel
+
+            assert isinstance(
+                model, PeftModel
+            ), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
+            if optimizer is not None:
+                self.add_lora_params_to_optimizer(model, optimizer)
+
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.precision)

@@ -340,8 +438,3 @@ def get_checkpoint_io(self) -> CheckpointIO:
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert isinstance(optimizer, LowLevelZeroOptimizer)
         return optimizer.no_sync()
-
-    def enable_lora(
-        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
-    ) -> nn.Module:
-        raise NotImplementedError