@@ -1033,7 +1033,14 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
10331033 return lr_weight
10341034
10351035 # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
1036- def prepare_optimizer_params (self , text_encoder_lr , unet_lr , default_lr , , unet_lora_plus_ratio = None , text_encoder_lora_plus_ratio = None ):
1036+ def prepare_optimizer_params (
1037+ self ,
1038+ text_encoder_lr ,
1039+ unet_lr ,
1040+ default_lr ,
1041+ unet_lora_plus_ratio = None ,
1042+ text_encoder_lora_plus_ratio = None
1043+ ):
10371044 self .requires_grad_ (True )
10381045 all_params = []
10391046
@@ -1068,7 +1075,11 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
10681075 return params
10691076
10701077 if self .text_encoder_loras :
1071- params = assemble_params (self .text_encoder_loras , text_encoder_lr , text_encoder_lora_plus_ratio )
1078+ params = assemble_params (
1079+ self .text_encoder_loras ,
1080+ text_encoder_lr if text_encoder_lr is not None else default_lr ,
1081+ text_encoder_lora_plus_ratio
1082+ )
10721083 all_params .extend (params )
10731084
10741085 if self .unet_loras :
@@ -1083,14 +1094,19 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
10831094
10841095 # blockごとにパラメータを設定する
10851096 for idx , block_loras in block_idx_to_lora .items ():
1086- if unet_lr is not None :
1087- params = assemble_params (block_loras , unet_lr * self .get_lr_weight (block_loras [0 ]), unet_lora_plus_ratio )
1088- elif default_lr is not None :
1089- params = assemble_params (block_loras , default_lr * self .get_lr_weight (block_loras [0 ]), unet_lora_plus_ratio )
1097+ params = assemble_params (
1098+ block_loras ,
1099+ (unet_lr if unet_lr is not None else default_lr ) * self .get_lr_weight (block_loras [0 ]),
1100+ unet_lora_plus_ratio
1101+ )
10901102 all_params .extend (params )
10911103
10921104 else :
1093- params = assemble_params (self .unet_loras , unet_lr , unet_lora_plus_ratio )
1105+ params = assemble_params (
1106+ self .unet_loras ,
1107+ default_lr if unet_lr is None else unet_lr ,
1108+ unet_lora_plus_ratio
1109+ )
10941110 all_params .extend (params )
10951111
10961112 return all_params
0 commit comments