@@ -351,26 +351,35 @@ def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
                     return group_id
         return -1
 
+    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
+        origin_param_id = id(origin_param)
+        lora_param_id = id(lora_param)
+        target_group_id = -1
+        for group_id, param_group in enumerate(optimizer.param_groups):
+            for p in param_group['params']:
+                if id(p) == lora_param_id:
+                    # the LoRA parameter is already tracked by the optimizer; signal with -2
+                    return -2
+                if id(p) == origin_param_id:
+                    target_group_id = group_id
+        return target_group_id
+
     def add_lora_params_to_optimizer(self, model, optimizer):
         """add lora parameters to optimizer"""
         name2param = {}
         for name, param in model.named_parameters():
             name2param[name] = param
 
-        optimizer_param_nums = 0
-        for param_group in optimizer.param_groups:
-            optimizer_param_nums += len(param_group['params'])
-
-        # Check if the optimizer is created after the model is transformed into a LoRa model.
-        if len(name2param) != optimizer_param_nums:
-            for name, param in name2param.items():
-                if 'lora_A' in name or 'lora_B' in name:
-                    origin_key = name.replace("lora_A.", "")
-                    origin_key = origin_key.replace("lora_B.", "")
-                    origin_key = origin_key.replace(f"{model.active_adapter}.", "")
-                    origin_param = name2param[origin_key]
-                    group_id = self.get_param_group_id(optimizer, origin_param)
-                    assert group_id != -1, "Parameter error, origin parameter does't exists."
+        for name, param in name2param.items():
+            if 'lora_A' in name or 'lora_B' in name:
+                origin_key = name.replace("lora_A.", "")
+                origin_key = origin_key.replace("lora_B.", "")
+                origin_key = origin_key.replace(f"{model.active_adapter}.", "")
+                origin_param = name2param[origin_key]
+                group_id = self.get_param_group_id(optimizer, origin_param, param)
+                if group_id == -1:
+                    warnings.warn(f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.")
+                elif group_id >= 0:
                     optimizer.param_groups[group_id]['params'].append(param)
 
     def configure(
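
The new variant's three-way return convention (`-2`: the LoRA parameter is already tracked, `-1`: the origin parameter is in no group, `>= 0`: the matching group index) can be exercised in isolation. A minimal sketch, assuming only `torch`; the free function below mirrors the method body from the hunk minus `self`, and `base`/`lora_a`/`orphan` are illustrative stand-ins:

```python
import torch
from torch import nn
from torch.nn import Parameter
from torch.optim import SGD, Optimizer

def get_param_group_id(optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
    """Mirror of the method added in the hunk, minus `self`."""
    origin_param_id = id(origin_param)
    lora_param_id = id(lora_param)
    target_group_id = -1
    for group_id, param_group in enumerate(optimizer.param_groups):
        for p in param_group['params']:
            if id(p) == lora_param_id:
                return -2  # LoRA param already present: nothing to do
            if id(p) == origin_param_id:
                target_group_id = group_id
    return target_group_id

base = nn.Linear(4, 4)                    # stands in for a frozen origin layer
lora_a = Parameter(torch.zeros(2, 4))     # stands in for a lora_A matrix
opt = SGD(base.parameters(), lr=1e-3)

assert get_param_group_id(opt, base.weight, lora_a) == 0   # origin found in group 0
opt.param_groups[0]['params'].append(lora_a)
assert get_param_group_id(opt, base.weight, lora_a) == -2  # already added
orphan = Parameter(torch.ones(1))                          # never given to the optimizer
assert get_param_group_id(opt, orphan, Parameter(torch.ones(1))) == -1
```

This is why the caller only appends when `group_id >= 0` and merely warns on `-1`: the `-2` case silently skips parameters that an optimizer created after LoRA wrapping already owns.
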
@@ -384,7 +393,8 @@ def configure(
         if self.lora_enabled:
             from peft import PeftModel
             assert isinstance(model, PeftModel), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
-            self.add_lora_params_to_optimizer(model, optimizer)
+            if optimizer is not None:
+                self.add_lora_params_to_optimizer(model, optimizer)
 
 
         if not isinstance(model, ModelWrapper):
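
Back in the first hunk, the `origin_key` rewriting that pairs a LoRA matrix with its frozen origin weight is pure string manipulation and can be checked standalone. A minimal sketch; the parameter name is an illustrative example of the `lora_A.<adapter>.` layout peft produces, not taken from the diff:

```python
active_adapter = "default"  # peft's default adapter name
name = "base_model.model.query.lora_A.default.weight"  # illustrative lora_A parameter name

origin_key = name.replace("lora_A.", "")
origin_key = origin_key.replace("lora_B.", "")
origin_key = origin_key.replace(f"{active_adapter}.", "")

assert origin_key == "base_model.model.query.weight"  # the frozen origin parameter's name
```

The second hunk's `optimizer is not None` guard then simply skips this bookkeeping when `configure` is called without an optimizer, e.g. for an inference-only setup.
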