4 changes: 3 additions & 1 deletion modules/modelLoader/GenericEmbeddingModelLoader.py
@@ -5,6 +5,7 @@
from modules.util.enum.ModelType import ModelType
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes
from modules.util.ModuleFilter import ModuleFilter


def make_embedding_model_loader(
@@ -32,6 +33,7 @@ def load(
model_type: ModelType,
model_names: ModelNames,
weight_dtypes: ModelWeightDtypes,
quant_filters: list[ModuleFilter] | None = None,
) -> model_class | None:
base_model_loader = model_loader_class()
embedding_loader = embedding_loader_class()
@@ -41,7 +43,7 @@
model.model_spec = self._load_default_model_spec(model_type)

if model_names.base_model is not None:
base_model_loader.load(model, model_type, model_names, weight_dtypes)
base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
embedding_loader.load(model, model_names.embedding.model_name, model_names)

return model
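
Note: the two generic loader factories below receive the identical change: load() grows an optional quant_filters parameter and forwards it verbatim to the concrete base model loader. A condensed sketch of the shared shape follows; it is not itself part of the diff, and the elided lines are marked with comments.

    def load(
            self,
            model_type: ModelType,
            model_names: ModelNames,
            weight_dtypes: ModelWeightDtypes,
            quant_filters: list[ModuleFilter] | None = None,  # new optional argument; None keeps the old behaviour
    ) -> model_class | None:
        ...  # model construction and model_spec handling elided (unchanged by this PR)
        # the filters are forwarded unchanged to the concrete base model loader
        base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
        return model

Because quant_filters defaults to None, existing call sites that do not pass it continue to behave as before.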
4 changes: 3 additions & 1 deletion modules/modelLoader/GenericFineTuneModelLoader.py
@@ -5,6 +5,7 @@
from modules.util.enum.ModelType import ModelType
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes
from modules.util.ModuleFilter import ModuleFilter


def make_fine_tune_model_loader(
@@ -32,6 +33,7 @@ def load(
model_type: ModelType,
model_names: ModelNames,
weight_dtypes: ModelWeightDtypes,
quant_filters: list[ModuleFilter] | None = None,
) -> model_class | None:
base_model_loader = model_loader_class()
if embedding_loader_class is not None:
@@ -42,7 +44,7 @@
self._load_internal_data(model, model_names.base_model)
model.model_spec = self._load_default_model_spec(model_type)

base_model_loader.load(model, model_type, model_names, weight_dtypes)
base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
if embedding_loader_class is not None:
embedding_loader.load(model, model_names.base_model, model_names)

4 changes: 3 additions & 1 deletion modules/modelLoader/GenericLoRAModelLoader.py
@@ -5,6 +5,7 @@
from modules.util.enum.ModelType import ModelType
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes
from modules.util.ModuleFilter import ModuleFilter


def make_lora_model_loader(
@@ -33,6 +34,7 @@ def load(
model_type: ModelType,
model_names: ModelNames,
weight_dtypes: ModelWeightDtypes,
quant_filters: list[ModuleFilter] | None = None,
) -> model_class | None:
base_model_loader = model_loader_class()
lora_model_loader = lora_loader_class()
@@ -44,7 +46,7 @@
model.model_spec = self._load_default_model_spec(model_type)

if model_names.base_model is not None:
base_model_loader.load(model, model_type, model_names, weight_dtypes)
base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
lora_model_loader.load(model, model_names)
if embedding_loader_class is not None:
embedding_loader.load(model, model_names.lora, model_names)
16 changes: 11 additions & 5 deletions modules/modelLoader/chroma/ChromaModelLoader.py
@@ -7,6 +7,7 @@
from modules.util.enum.ModelType import ModelType
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes
from modules.util.ModuleFilter import ModuleFilter

import torch

@@ -33,10 +34,11 @@ def __load_internal(
base_model_name: str,
transformer_model_name: str,
vae_model_name: str,
quant_filters: list[ModuleFilter],
):
if os.path.isfile(os.path.join(base_model_name, "meta.json")):
self.__load_diffusers(
model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name,
model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name, quant_filters,
)
else:
raise Exception("not an internal model")
@@ -49,6 +51,7 @@ def __load_diffusers(
base_model_name: str,
transformer_model_name: str,
vae_model_name: str,
quant_filters: list[ModuleFilter],
):
diffusers_sub = []
if not transformer_model_name:
@@ -104,7 +107,7 @@ def __load_diffusers(
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer == DataType.GGUF else None,
)
transformer = self._convert_diffusers_sub_module_to_dtype(
transformer, weight_dtypes.transformer, weight_dtypes.train_dtype
transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters,
)
else:
transformer = self._load_diffusers_sub_module(
@@ -113,6 +116,7 @@ def __load_diffusers(
weight_dtypes.train_dtype,
base_model_name,
"transformer",
quant_filters,
)

model.model_type = model_type
@@ -130,6 +134,7 @@ def __load_safetensors(
base_model_name: str,
transformer_model_name: str,
vae_model_name: str,
quant_filters: list[ModuleFilter],
):
#no single file .safetensors for Chroma available at the time of writing this code
raise NotImplementedError("Loading of single file Chroma models not supported. Use the diffusers model instead. Optionally, transformer-only safetensor files can be loaded by overriding the transformer.")
@@ -140,28 +145,29 @@ def load(
model_type: ModelType,
model_names: ModelNames,
weight_dtypes: ModelWeightDtypes,
quant_filters: list[ModuleFilter] | None = None,
):
stacktraces = []

try:
self.__load_internal(
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model,
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters,
)
return
except Exception:
stacktraces.append(traceback.format_exc())

try:
self.__load_diffusers(
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model,
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters,
)
return
except Exception:
stacktraces.append(traceback.format_exc())

try:
self.__load_safetensors(
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model,
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters,
)
return
except Exception:
18 changes: 12 additions & 6 deletions modules/modelLoader/flux/FluxModelLoader.py
@@ -7,6 +7,7 @@
from modules.util.enum.ModelType import ModelType
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes
from modules.util.ModuleFilter import ModuleFilter

import torch

@@ -36,11 +37,12 @@ def __load_internal(
vae_model_name: str,
include_text_encoder_1: bool,
include_text_encoder_2: bool,
quant_filters: list[ModuleFilter],
):
if os.path.isfile(os.path.join(base_model_name, "meta.json")):
self.__load_diffusers(
model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name,
include_text_encoder_1, include_text_encoder_2,
include_text_encoder_1, include_text_encoder_2, quant_filters,
)
else:
raise Exception("not an internal model")
@@ -55,6 +57,7 @@ def __load_diffusers(
vae_model_name: str,
include_text_encoder_1: bool,
include_text_encoder_2: bool,
quant_filters: list[ModuleFilter],
):
diffusers_sub = []
transformers_sub = []
@@ -140,7 +143,7 @@ def __load_diffusers(
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer == DataType.GGUF else None,
)
transformer = self._convert_diffusers_sub_module_to_dtype(
transformer, weight_dtypes.transformer, weight_dtypes.train_dtype
transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters,
)
else:
transformer = self._load_diffusers_sub_module(
@@ -149,6 +152,7 @@ def __load_diffusers(
weight_dtypes.train_dtype,
base_model_name,
"transformer",
quant_filters,
)

model.model_type = model_type
@@ -170,6 +174,7 @@ def __load_safetensors(
vae_model_name: str,
include_text_encoder_1: bool,
include_text_encoder_2: bool,
quant_filters: list[ModuleFilter],
):
transformer = FluxTransformer2DModel.from_single_file(
#always load transformer separately even though FluxPipeLine.from_single_file() could load it, to avoid loading in float32:
@@ -222,7 +227,7 @@ def __load_safetensors(
print("text encoder 2 (t5) not loaded, continuing without it")

transformer = self._convert_diffusers_sub_module_to_dtype(
pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype
pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters,
)

model.model_type = model_type
@@ -240,13 +245,14 @@ def load(
model_type: ModelType,
model_names: ModelNames,
weight_dtypes: ModelWeightDtypes,
quant_filters: list[ModuleFilter] | None = None,
):
stacktraces = []

try:
self.__load_internal(
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model,
model_names.include_text_encoder, model_names.include_text_encoder_2,
model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters,
)
return
except Exception:
@@ -255,7 +261,7 @@ def load(
try:
self.__load_diffusers(
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model,
model_names.include_text_encoder, model_names.include_text_encoder_2,
model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters,
)
return
except Exception:
@@ -264,7 +270,7 @@ def load(
try:
self.__load_safetensors(
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model,
model_names.include_text_encoder, model_names.include_text_encoder_2,
model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters,
)
return
except Exception:
16 changes: 11 additions & 5 deletions modules/modelLoader/hiDream/HiDreamModelLoader.py
@@ -7,6 +7,7 @@
from modules.util.enum.ModelType import ModelType
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes
from modules.util.ModuleFilter import ModuleFilter

from diffusers import (
AutoencoderKL,
@@ -42,11 +43,12 @@ def __load_internal(
include_text_encoder_2: bool,
include_text_encoder_3: bool,
include_text_encoder_4: bool,
quant_filters: list[ModuleFilter],
):
if os.path.isfile(os.path.join(base_model_name, "meta.json")):
self.__load_diffusers(
model, model_type, weight_dtypes, base_model_name, text_encoder_4_model_name, vae_model_name,
include_text_encoder_1, include_text_encoder_2, include_text_encoder_3, include_text_encoder_4,
include_text_encoder_1, include_text_encoder_2, include_text_encoder_3, include_text_encoder_4, quant_filters,
)
else:
raise Exception("not an internal model")
@@ -63,6 +65,7 @@ def __load_diffusers(
include_text_encoder_2: bool,
include_text_encoder_3: bool,
include_text_encoder_4: bool,
quant_filters: list[ModuleFilter],
):
diffusers_sub = []
transformers_sub = []
@@ -191,6 +194,7 @@ def __load_diffusers(
weight_dtypes.train_dtype,
base_model_name,
"transformer",
quant_filters,
)

model.model_type = model_type
@@ -218,6 +222,7 @@ def __load_safetensors(
include_text_encoder_2: bool,
include_text_encoder_3: bool,
include_text_encoder_4: bool,
quant_filters: list[ModuleFilter],
):
pipeline = HiDreamImagePipeline.from_single_file(
pretrained_model_link_or_path=base_model_name,
@@ -264,7 +269,7 @@ def __load_safetensors(
print("text encoder 2 (t5) not loaded, continuing without it")

transformer = self._convert_diffusers_sub_module_to_dtype(
pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype
pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters,
)

model.model_type = model_type
Expand All @@ -290,6 +295,7 @@ def load(
model_type: ModelType,
model_names: ModelNames,
weight_dtypes: ModelWeightDtypes,
quant_filters: list[ModuleFilter] | None = None,
):
stacktraces = []

@@ -298,7 +304,7 @@
model, model_type, weight_dtypes, model_names.base_model,
model_names.text_encoder_4, model_names.vae_model,
model_names.include_text_encoder, model_names.include_text_encoder_2,
model_names.include_text_encoder_3, model_names.include_text_encoder_4,
model_names.include_text_encoder_3, model_names.include_text_encoder_4, quant_filters,
)
self.__after_load(model)
return
@@ -310,7 +316,7 @@
model, model_type, weight_dtypes, model_names.base_model,
model_names.text_encoder_4, model_names.vae_model,
model_names.include_text_encoder, model_names.include_text_encoder_2,
model_names.include_text_encoder_3, model_names.include_text_encoder_4,
model_names.include_text_encoder_3, model_names.include_text_encoder_4, quant_filters,
)
self.__after_load(model)
return
@@ -322,7 +328,7 @@
model, model_type, weight_dtypes, model_names.base_model,
model_names.text_encoder_4, model_names.vae_model,
model_names.include_text_encoder, model_names.include_text_encoder_2,
model_names.include_text_encoder_3, model_names.include_text_encoder_4,
model_names.include_text_encoder_3, model_names.include_text_encoder_4, quant_filters,
)
self.__after_load(model)
return
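
Note: taken together, each concrete loader's load() now accepts the same optional list of ModuleFilter objects and threads it into the transformer loading and dtype-conversion helpers, so a caller can restrict quantization to selected modules. A hypothetical call site is sketched below for orientation; the class and import paths follow the file names in this diff, but the ModuleFilter constructor argument is an assumption for illustration only, since its real interface is not part of this diff.

from modules.modelLoader.flux.FluxModelLoader import FluxModelLoader
from modules.util.ModuleFilter import ModuleFilter


def load_flux_with_filters(model, model_type, model_names, weight_dtypes):
    # Assumed for illustration: a ModuleFilter keyed by a module-name pattern.
    # The actual constructor and matching rules may differ.
    filters = [ModuleFilter("transformer_blocks")]

    loader = FluxModelLoader()
    # Passing quant_filters=None (or omitting it) keeps the previous behaviour.
    loader.load(model, model_type, model_names, weight_dtypes, quant_filters=filters)
    return model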