From a6f1ed2e140eb4d4d37c0bb0502a7c0fd0621f5f Mon Sep 17 00:00:00 2001
From: tamlog06
Date: Sun, 18 Feb 2024 13:20:47 +0000
Subject: [PATCH] fix dylora create_modules error

---
 networks/dylora.py | 31 ++++++++++++++++++++++++++++---
 1 file changed, 28 insertions(+), 3 deletions(-)

diff --git a/networks/dylora.py b/networks/dylora.py
index e5a55d198..64e39eaf7 100644
--- a/networks/dylora.py
+++ b/networks/dylora.py
@@ -12,7 +12,9 @@
 import math
 import os
 import random
-from typing import List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
+from diffusers import AutoencoderKL
+from transformers import CLIPTextModel
 
 import torch
 from torch import nn
@@ -165,7 +167,15 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
         super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
 
 
-def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
+def create_network(
+    multiplier: float,
+    network_dim: Optional[int],
+    network_alpha: Optional[float],
+    vae: AutoencoderKL,
+    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
+    unet,
+    **kwargs,
+):
     if network_dim is None:
         network_dim = 4  # default
     if network_alpha is None:
@@ -182,6 +192,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         conv_alpha = 1.0
     else:
         conv_alpha = float(conv_alpha)
+
     if unit is not None:
         unit = int(unit)
     else:
@@ -306,8 +317,22 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
                             lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
                             loras.append(lora)
             return loras
+
+        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
+
+        self.text_encoder_loras = []
+        for i, text_encoder in enumerate(text_encoders):
+            if len(text_encoders) > 1:
+                index = i + 1
+                print(f"create LoRA for Text Encoder {index}")
+            else:
+                index = None
+                print(f"create LoRA for Text Encoder")
+
+            text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+            self.text_encoder_loras.extend(text_encoder_loras)
 
-        self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+        # self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
         print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
 
         # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
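
Context for the change, as an illustrative sketch (not part of the patch): with SDXL, create_network receives a list of two CLIP text encoders rather than a single module, so the old single call to create_modules(False, text_encoder, ...) would end up walking named_modules() on a plain Python list and fail. The patch normalizes the argument to a list and loops over it. A minimal standalone Python sketch of that normalization pattern follows; FakeEncoder and collect_module_names are hypothetical names used only for illustration, not code from sd-scripts.

    # Sketch only: FakeEncoder is a hypothetical stand-in for the real text encoders.
    from torch import nn

    class FakeEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

    def collect_module_names(text_encoder):
        # Accept either a single encoder (SD1.x/2.x) or a list of encoders (SDXL),
        # mirroring the `text_encoders = ... if ... else [text_encoder]` line in the patch.
        text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
        names = []
        for te in text_encoders:
            # named_modules() exists on nn.Module but not on a list, which is why
            # passing a list straight into create_modules raised an error before the patch.
            names.extend(name for name, _ in te.named_modules())
        return names

    print(collect_module_names(FakeEncoder()))                   # single encoder
    print(collect_module_names([FakeEncoder(), FakeEncoder()]))  # SDXL-style list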