 import torch
 import re
 from library.utils import setup_logging
+
 setup_logging()
 import logging
+
 logger = logging.getLogger(__name__)
 
 RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -504,6 +506,15 @@ def create_network(
     if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
         network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
 
+    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
+    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
+    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
+    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
+    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
+    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
     return network
 
 
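Note: these ratios reach create_network as strings through the network kwargs (typically forwarded from --network_args in this codebase, hedged as an assumption), hence the explicit float coercion above. A minimal sketch of that parsing with hypothetical values:

# e.g. --network_args "loraplus_lr_ratio=16" "loraplus_unet_lr_ratio=8" (assumed CLI form)
kwargs = {"loraplus_lr_ratio": "16", "loraplus_unet_lr_ratio": "8"}
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
assert loraplus_lr_ratio == 16.0  # per-component ratios override this global value downstream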
@@ -529,7 +540,9 @@ def parse_floats(s):
             len(block_dims) == num_total_blocks
         ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
     else:
-        logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
+        logger.warning(
+            f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
+        )
         block_dims = [network_dim] * num_total_blocks
 
     if block_alphas is not None:
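For context, the else-branch above makes the omitted-block_dims case equivalent to a uniform per-block dim; a minimal sketch of that fallback (the total block count of 25 is an assumption about this codebase):

network_dim = 8
num_total_blocks = 25  # assumed: down blocks + mid + up blocks as counted by this file
block_dims = [network_dim] * num_total_blocks  # every block gets the global dim
assert all(d == network_dim for d in block_dims)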
@@ -803,21 +816,31 @@ def __init__(
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
 
+        self.loraplus_lr_ratio = None
+        self.loraplus_unet_lr_ratio = None
+        self.loraplus_text_encoder_lr_ratio = None
+
         if modules_dim is not None:
             logger.info(f"create LoRA network from weights")
         elif block_dims is not None:
             logger.info(f"create LoRA network from block_dims")
-            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+            logger.info(
+                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+            )
             logger.info(f"block_dims: {block_dims}")
             logger.info(f"block_alphas: {block_alphas}")
             if conv_block_dims is not None:
                 logger.info(f"conv_block_dims: {conv_block_dims}")
                 logger.info(f"conv_block_alphas: {conv_block_alphas}")
         else:
             logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
-            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+            logger.info(
+                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+            )
             if self.conv_lora_dim is not None:
-                logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
+                logger.info(
+                    f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
+                )
 
         # create module instances
         def create_modules(
@@ -939,6 +962,11 @@ def create_modules(
             assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
             names.add(lora.lora_name)
 
+    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
+        self.loraplus_lr_ratio = loraplus_lr_ratio
+        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
+        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
+
     def set_multiplier(self, multiplier):
         self.multiplier = multiplier
         for lora in self.text_encoder_loras + self.unet_loras:
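The new setter mirrors set_multiplier just below it and simply stores the three ratios for prepare_optimizer_params to read later; a minimal usage sketch (values are illustrative):

# network is the LoRANetwork returned by create_network(...)
network.set_loraplus_lr_ratio(
    loraplus_lr_ratio=16.0,              # global LoRA+ ratio
    loraplus_unet_lr_ratio=None,         # None -> U-Net falls back to the global ratio
    loraplus_text_encoder_lr_ratio=4.0,  # explicit override for the text encoders
)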
@@ -1033,15 +1061,7 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight
 
     # It might be good to allow setting separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(
-        self,
-        text_encoder_lr,
-        unet_lr,
-        default_lr,
-        text_encoder_loraplus_ratio=None,
-        unet_loraplus_ratio=None,
-        loraplus_ratio=None
-    ):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
         self.requires_grad_(True)
         all_params = []
 
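With the simplified signature the ratios no longer travel as call arguments, so the trainer-side call site shrinks accordingly; a sketch of the new call (learning-rate values are illustrative):

trainable_params = network.prepare_optimizer_params(text_encoder_lr=5e-5, unet_lr=1e-4, default_lr=1e-4)
optimizer = torch.optim.AdamW(trainable_params)  # each group carries its own "lr", including any boosted LoRA+ groups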
@@ -1078,7 +1098,7 @@ def assemble_params(loras, lr, ratio):
             params = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_loraplus_ratio or loraplus_ratio
+                self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
             )
             all_params.extend(params)
 
@@ -1097,15 +1117,15 @@ def assemble_params(loras, lr, ratio):
                     params = assemble_params(
                         block_loras,
                         (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
-                        unet_loraplus_ratio or loraplus_ratio
+                        self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
                     )
                     all_params.extend(params)
 
             else:
                 params = assemble_params(
                     self.unet_loras,
                     unet_lr if unet_lr is not None else default_lr,
-                    unet_loraplus_ratio or loraplus_ratio
+                    self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
                 )
                 all_params.extend(params)
 
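assemble_params itself is outside these hunks; the ratio it receives implements the LoRA+ rule that the up-projection (lora_up, i.e. B) matrices train at lr * ratio while everything else keeps lr. A self-contained sketch of that grouping, as an illustration rather than the committed implementation:

def assemble_params_sketch(loras, lr, ratio):
    """Split LoRA parameters into a base group (lr) and a LoRA+ group (lr * ratio)."""
    base, plus = [], []
    for lora in loras:
        for name, param in lora.named_parameters():
            # LoRA+ boosts only the up-projection weights; ratio=None disables the split.
            if ratio is not None and "lora_up" in name:
                plus.append(param)
            else:
                base.append(param)
    groups = [{"params": base, "lr": lr}]
    if plus:
        groups.append({"params": plus, "lr": lr * ratio})
    return groups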