 import logging
+import warnings
 import os
 from functools import partial
 from pathlib import Path

 import torch
 import torch.nn as nn
+from torch.nn import Parameter
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
@@ -335,13 +337,27 @@ def enable_lora(
         from peft import PeftModel, get_peft_model
         assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
         self.lora_enabled = True
+        warnings.warn("You have enabled LoRA training. Please check hyperparameters such as lr.")

         if pretrained_dir is None:
             peft_model = get_peft_model(model, lora_config)
         else:
             peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
         return peft_model

+    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, add_param: Parameter):
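+        # Returns the index of the optimizer param group that contains origin_param,
+        # -2 if add_param is already registered in a group, or -1 if origin_param is not found.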
+        origin_param_id = id(origin_param)
+        add_param_id = id(add_param)
+        group_id = -1
+        for pg_id, param_group in enumerate(optimizer.param_groups):
+            for p in param_group['params']:
+                if id(p) == add_param_id:
+                    return -2
+                if id(p) == origin_param_id:
+                    group_id = pg_id
+        return group_id
+
+
     def configure(
         self,
         model: nn.Module,
@@ -353,12 +369,21 @@ def configure(
         if self.lora_enabled:
             from peft import PeftModel
             assert isinstance(model, PeftModel), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
-
-            optim_params_nums = 0
-            for param_group in optimizer.param_groups:
-                optim_params_nums += len(param_group['params'])
-            model_params_nums = len(list(model.named_parameters()))
-            assert optim_params_nums == model_params_nums, "Optimizer should be initialized after enabling lora."
+
+            # add lora parameters to optimizer
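+            # Each lora_A / lora_B weight is appended to the same optimizer param group
+            # as the base weight it wraps; parameters already in the optimizer are skipped.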
+            name2param = {}
+            for name, param in model.named_parameters():
+                name2param[name] = param
+            for name, param in name2param.items():
+                if 'lora_A' in name or 'lora_B' in name:
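+                    # strip the 'lora_A.' / 'lora_B.' and active-adapter name segments
+                    # to recover the key of the wrapped base parameter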
+                    origin_key = name.replace("lora_A.", "")
+                    origin_key = origin_key.replace("lora_B.", "")
+                    origin_key = origin_key.replace(f"{model.active_adapter}.", "")
+                    origin_param = name2param[origin_key]
+                    group_id = self.get_param_group_id(optimizer, origin_param, param)
+                    assert group_id != -1, "Parameter error: the origin parameter does not exist."
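+                    # group_id == -2 means this LoRA parameter is already in the optimizer, so skip it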
+                    if group_id >= 0:
+                        optimizer.param_groups[group_id]['params'].append(param)

         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.precision)
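
For reference, a minimal usage sketch of the call order this change enables, under stated assumptions: the plugin is driven through ColossalAI's Booster, build_model() is a hypothetical helper standing in for however the nn.Module is constructed, and the LoraConfig values are illustrative; only the pretrained_dir and lora_config keywords come from the diff above.

# Sketch (not part of the patch): LoRA is enabled before boosting, and the optimizer
# may already exist; configure() then appends the new lora_A / lora_B parameters
# to the optimizer's existing param groups instead of asserting a count match.
import torch
from peft import LoraConfig

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

model = build_model()  # hypothetical helper returning an nn.Module
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

plugin = LowLevelZeroPlugin()
booster = Booster(plugin=plugin)

lora_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1)  # illustrative values
model = plugin.enable_lora(model, pretrained_dir=None, lora_config=lora_config)  # wraps model as a PeftModel
model, optimizer, *_ = booster.boost(model, optimizer)  # LoRA params join the optimizer groups here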