@@ -1035,21 +1035,43 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
10351035 return lr_weight
10361036
10371037 # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
1038- def prepare_optimizer_params (self , text_encoder_lr , unet_lr , default_lr ):
1038+ def prepare_optimizer_params (self , text_encoder_lr , unet_lr , default_lr , unet_lora_plus_ratio = None , text_encoder_lora_plus_ratio = None ):
10391039 self .requires_grad_ (True )
10401040 all_params = []
10411041
1042- def enumerate_params (loras ):
1043- params = []
def assemble_params(loras, lr, lora_plus_ratio):
    """Build optimizer param-group dicts for a set of LoRA modules.

    Implements the LoRA+ scheme: when ``lora_plus_ratio`` is given, the
    ``lora_up`` (B-matrix) parameters are split into a separate group whose
    learning rate is ``lr * lora_plus_ratio``; everything else keeps ``lr``.

    Args:
        loras: iterable of LoRA modules; each must expose ``lora_name`` and a
            ``named_parameters()`` iterator (torch.nn.Module style).
        lr: base learning rate for these modules, or None to fall back to the
            optimizer's default lr.
        lora_plus_ratio: multiplier applied to the ``lora_up`` group, or None
            to disable the LoRA+ split entirely.

    Returns:
        List of param-group dicts of the form ``{"params": [...], "lr": x}``.
        Groups that are empty, or whose effective lr is exactly 0 (meaning
        "do not train"), are omitted.
    """
    param_groups = {"lora": {}, "plus": {}}
    for lora in loras:
        for name, param in lora.named_parameters():
            # Route B-matrix (lora_up) weights to the high-LR group only
            # when LoRA+ is enabled; otherwise everything goes to "lora".
            if lora_plus_ratio is not None and "lora_up" in name:
                param_groups["plus"][f"{lora.lora_name}.{name}"] = param
            else:
                param_groups["lora"][f"{lora.lora_name}.{name}"] = param

    params = []
    for key, group in param_groups.items():
        # Skip empty groups: the "plus" group is always empty when LoRA+ is
        # disabled, and optimizers should not receive empty param lists.
        if not group:
            continue
        param_data = {"params": list(group.values())}
        if lr is not None:
            param_data["lr"] = lr * lora_plus_ratio if key == "plus" else lr
        # lr == 0 means this group is frozen; drop it entirely.
        if param_data.get("lr") == 0:
            continue
        params.append(param_data)

    return params
10471071
10481072 if self .text_encoder_loras :
1049- param_data = {"params" : enumerate_params (self .text_encoder_loras )}
1050- if text_encoder_lr is not None :
1051- param_data ["lr" ] = text_encoder_lr
1052- all_params .append (param_data )
1073+ params = assemble_params (self .text_encoder_loras , text_encoder_lr , text_encoder_lora_plus_ratio )
1074+ all_params .extend (params )
10531075
10541076 if self .unet_loras :
10551077 if self .block_lr :
@@ -1063,21 +1085,15 @@ def enumerate_params(loras):
10631085
10641086 # blockごとにパラメータを設定する
10651087 for idx , block_loras in block_idx_to_lora .items ():
1066- param_data = {"params" : enumerate_params (block_loras )}
1067-
10681088 if unet_lr is not None :
1069- param_data [ "lr" ] = unet_lr * self .get_lr_weight (block_loras [0 ])
1089+ params = assemble_params ( block_loras , unet_lr * self .get_lr_weight (block_loras [0 ]), unet_lora_plus_ratio )
10701090 elif default_lr is not None :
1071- param_data ["lr" ] = default_lr * self .get_lr_weight (block_loras [0 ])
1072- if ("lr" in param_data ) and (param_data ["lr" ] == 0 ):
1073- continue
1074- all_params .append (param_data )
1091+ params = assemble_params (block_loras , default_lr * self .get_lr_weight (block_loras [0 ]), unet_lora_plus_ratio )
1092+ all_params .extend (params )
10751093
10761094 else :
1077- param_data = {"params" : enumerate_params (self .unet_loras )}
1078- if unet_lr is not None :
1079- param_data ["lr" ] = unet_lr
1080- all_params .append (param_data )
1095+ params = assemble_params (self .unet_loras , unet_lr , unet_lora_plus_ratio )
1096+ all_params .extend (params )
10811097
10821098 return all_params
10831099
0 commit comments