 import logging
 import warnings
+import enum
 import os
 from functools import partial
 from pathlib import Path
@@ -43,6 +44,11 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):

 SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]

+class OptimizerParamCheckState(enum.Enum):
+    ORIGIN_PARAM_FINDED = 0
+    ORIGIN_PARAM_NOT_FIND = -1
+    LORA_PARM_EXISTED = -2
+

 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
     def __init__(self, module: nn.Module, precision: str) -> None:
@@ -354,15 +360,18 @@ def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
     def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
         origin_param_id = id(origin_param)
         lora_param_id = id(lora_param)
-        target_group_id = -1
+        target_group_id = None
         for group_id, param_group in enumerate(optimizer.param_groups):
             for p in param_group['params']:
                 if id(p) == lora_param_id:
                     # check if the lora parameter exists.
-                    return -2
+                    return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
                 if id(p) == origin_param_id:
                     target_group_id = group_id
-        return target_group_id
+        if target_group_id is not None:
+            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED
+        else:
+            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND

     def add_lora_params_to_optimizer(self, model, optimizer):
         """ add lora parameters to optimizer """
@@ -374,12 +383,12 @@ def add_lora_params_to_optimizer(self, model, optimizer):
             if 'lora_A' in name or 'lora_B' in name:
                 origin_key = name.replace("lora_A.", "")
                 origin_key = origin_key.replace("lora_B.", "")
-                origin_key = origin_key.replace(f"{model.active_adapter}.", "")
+                origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
                 origin_param = name2param[origin_key]
-                group_id = self.get_param_group_id(optimizer, origin_param, param)
-                if group_id == -1:
+                group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
+                if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
                     warnings.warn("Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.")
-                elif group_id >= 0:
+                elif check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED and group_id is not None and group_id >= 0:
                     optimizer.param_groups[group_id]['params'].append(param)

     def configure(
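Note: a minimal sketch of how the new (group_id, check_state) contract of get_param_group_id would be consumed by a caller; the `plugin`, `optimizer`, `origin_param`, and `lora_param` objects below are hypothetical placeholders for illustration, not part of this commit.

# Hypothetical caller of the new tuple-returning get_param_group_id.
group_id, check_state = plugin.get_param_group_id(optimizer, origin_param, lora_param)
if check_state == OptimizerParamCheckState.LORA_PARM_EXISTED:
    pass  # LoRA parameter is already tracked by the optimizer; nothing to add
elif check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
    warnings.warn("origin parameter not found in optimizer.param_groups")
elif check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED:
    # attach the LoRA parameter to the same param group as its origin parameter
    optimizer.param_groups[group_id]['params'].append(lora_param)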