@@ -1033,22 +1033,43 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
10331033 return lr_weight
10341034
10351035 # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
1036- def prepare_optimizer_params (self , text_encoder_lr , unet_lr , default_lr ):
1036+ def prepare_optimizer_params (self , text_encoder_lr , unet_lr , default_lr , , unet_lora_plus_ratio = None , text_encoder_lora_plus_ratio = None ):
10371037 self .requires_grad_ (True )
10381038 all_params = []
10391039
1040- def enumerate_params (loras : List [LoRAModule ]):
1041- params = []
1040+ def assemble_params (loras : List [LoRAModule ], lr , lora_plus_ratio ):
1041+ param_groups = { "lora" : {}, "plus" : {}}
10421042 for lora in loras :
1043- # params.extend(lora.parameters())
1044- params .extend (lora .get_trainable_params ())
1043+ for name , param in lora .get_trainable_named_params ():
1044+ if lora_plus_ratio is not None and "lora_up" in name :
1045+ param_groups ["plus" ][f"{ lora .lora_name } .{ name } " ] = param
1046+ else :
1047+ param_groups ["lora" ][f"{ lora .lora_name } .{ name } " ] = param
1048+
1049+ # assigned_param_groups = ""
1050+ # for group in param_groups:
1051+ # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
1052+ # logger.info(assigned_param_groups)
1053+
1054+ params = []
1055+ for key in param_groups .keys ():
1056+ param_data = {"params" : param_groups [key ].values ()}
1057+ if lr is not None :
1058+ if key == "plus" :
1059+ param_data ["lr" ] = lr * lora_plus_ratio
1060+ else :
1061+ param_data ["lr" ] = lr
1062+
1063+ if ("lr" in param_data ) and (param_data ["lr" ] == 0 ):
1064+ continue
1065+
1066+ params .append (param_data )
1067+
10451068 return params
10461069
10471070 if self .text_encoder_loras :
1048- param_data = {"params" : enumerate_params (self .text_encoder_loras )}
1049- if text_encoder_lr is not None :
1050- param_data ["lr" ] = text_encoder_lr
1051- all_params .append (param_data )
1071+ params = assemble_params (self .text_encoder_loras , text_encoder_lr , text_encoder_lora_plus_ratio )
1072+ all_params .extend (params )
10521073
10531074 if self .unet_loras :
10541075 if self .block_lr :
@@ -1062,21 +1083,15 @@ def enumerate_params(loras: List[LoRAModule]):
10621083
10631084 # blockごとにパラメータを設定する
10641085 for idx , block_loras in block_idx_to_lora .items ():
1065- param_data = {"params" : enumerate_params (block_loras )}
1066-
10671086 if unet_lr is not None :
1068- param_data [ "lr" ] = unet_lr * self .get_lr_weight (block_loras [0 ])
1087+ params = assemble_params ( block_loras , unet_lr * self .get_lr_weight (block_loras [0 ]), unet_lora_plus_ratio )
10691088 elif default_lr is not None :
1070- param_data ["lr" ] = default_lr * self .get_lr_weight (block_loras [0 ])
1071- if ("lr" in param_data ) and (param_data ["lr" ] == 0 ):
1072- continue
1073- all_params .append (param_data )
1089+ params = assemble_params (block_loras , default_lr * self .get_lr_weight (block_loras [0 ]), unet_lora_plus_ratio )
1090+ all_params .extend (params )
10741091
10751092 else :
1076- param_data = {"params" : enumerate_params (self .unet_loras )}
1077- if unet_lr is not None :
1078- param_data ["lr" ] = unet_lr
1079- all_params .append (param_data )
1093+ params = assemble_params (self .unet_loras , unet_lr , unet_lora_plus_ratio )
1094+ all_params .extend (params )
10801095
10811096 return all_params
10821097
@@ -1093,6 +1108,9 @@ def on_epoch_start(self, text_encoder, unet):
10931108 def get_trainable_params (self ):
10941109 return self .parameters ()
10951110
1111+ def get_trainable_named_params (self ):
1112+ return self .named_parameters ()
1113+
10961114 def save_weights (self , file , dtype , metadata ):
10971115 if metadata is not None and len (metadata ) == 0 :
10981116 metadata = None
0 commit comments