Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 7 additions & 42 deletions modules/modelLoader/ChromaEmbeddingModelLoader.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,12 @@
from modules.model.ChromaModel import ChromaModel
from modules.modelLoader.BaseModelLoader import BaseModelLoader
from modules.modelLoader.chroma.ChromaEmbeddingLoader import ChromaEmbeddingLoader
from modules.modelLoader.chroma.ChromaModelLoader import ChromaModelLoader
from modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin
from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin
from modules.modelLoader.GenericEmbeddingModelLoader import make_embedding_model_loader
from modules.util.enum.ModelType import ModelType
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes


class ChromaEmbeddingModelLoader(
BaseModelLoader,
ModelSpecModelLoaderMixin,
InternalModelLoaderMixin,
):
def __init__(self):
super().__init__()

def _default_model_spec_name(
self,
model_type: ModelType,
) -> str | None:
match model_type:
case ModelType.CHROMA_1:
return "resources/sd_model_spec/chroma-embedding.json"
case _:
return None

def load(
self,
model_type: ModelType,
model_names: ModelNames,
weight_dtypes: ModelWeightDtypes,
) -> ChromaModel | None:
base_model_loader = ChromaModelLoader()
embedding_loader = ChromaEmbeddingLoader()

model = ChromaModel(model_type=model_type)
self._load_internal_data(model, model_names.embedding.model_name)
model.model_spec = self._load_default_model_spec(model_type)

if model_names.base_model is not None:
base_model_loader.load(model, model_type, model_names, weight_dtypes)
embedding_loader.load(model, model_names.embedding.model_name, model_names)

return model
ChromaEmbeddingModelLoader = make_embedding_model_loader(
model_spec_map={ModelType.CHROMA_1: "resources/sd_model_spec/chroma-embedding.json"},
model_class=ChromaModel,
model_loader_class=ChromaModelLoader,
embedding_loader_class=ChromaEmbeddingLoader,
)
49 changes: 7 additions & 42 deletions modules/modelLoader/ChromaFineTuneModelLoader.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,12 @@
from modules.model.ChromaModel import ChromaModel
from modules.modelLoader.BaseModelLoader import BaseModelLoader
from modules.modelLoader.chroma.ChromaEmbeddingLoader import ChromaEmbeddingLoader
from modules.modelLoader.chroma.ChromaModelLoader import ChromaModelLoader
from modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin
from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin
from modules.modelLoader.GenericFineTuneModelLoader import make_fine_tune_model_loader
from modules.util.enum.ModelType import ModelType
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes


class ChromaFineTuneModelLoader(
BaseModelLoader,
ModelSpecModelLoaderMixin,
InternalModelLoaderMixin,
):
def __init__(self):
super().__init__()

def _default_model_spec_name(
self,
model_type: ModelType,
) -> str | None:
match model_type:
case ModelType.CHROMA_1:
return "resources/sd_model_spec/chroma.json"
case _:
return None

def load(
self,
model_type: ModelType,
model_names: ModelNames,
weight_dtypes: ModelWeightDtypes,
) -> ChromaModel | None:
base_model_loader = ChromaModelLoader()
embedding_loader = ChromaEmbeddingLoader()

model = ChromaModel(model_type=model_type)

self._load_internal_data(model, model_names.base_model)
model.model_spec = self._load_default_model_spec(model_type)

base_model_loader.load(model, model_type, model_names, weight_dtypes)
embedding_loader.load(model, model_names.base_model, model_names)

return model
ChromaFineTuneModelLoader = make_fine_tune_model_loader(
model_spec_map={ModelType.CHROMA_1: "resources/sd_model_spec/chroma.json"},
model_class=ChromaModel,
model_loader_class=ChromaModelLoader,
embedding_loader_class=ChromaEmbeddingLoader,
)
52 changes: 8 additions & 44 deletions modules/modelLoader/ChromaLoRAModelLoader.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,14 @@
from modules.model.ChromaModel import ChromaModel
from modules.modelLoader.BaseModelLoader import BaseModelLoader
from modules.modelLoader.chroma.ChromaEmbeddingLoader import ChromaEmbeddingLoader
from modules.modelLoader.chroma.ChromaLoRALoader import ChromaLoRALoader
from modules.modelLoader.chroma.ChromaModelLoader import ChromaModelLoader
from modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin
from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin
from modules.modelLoader.GenericLoRAModelLoader import make_lora_model_loader
from modules.util.enum.ModelType import ModelType
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes


class ChromaLoRAModelLoader(
BaseModelLoader,
ModelSpecModelLoaderMixin,
InternalModelLoaderMixin,
):
def __init__(self):
super().__init__()

def _default_model_spec_name(
self,
model_type: ModelType,
) -> str | None:
match model_type:
case ModelType.CHROMA_1:
return "resources/sd_model_spec/chroma-lora.json"
case _:
return None

def load(
self,
model_type: ModelType,
model_names: ModelNames,
weight_dtypes: ModelWeightDtypes,
) -> ChromaModel | None:
base_model_loader = ChromaModelLoader()
lora_model_loader = ChromaLoRALoader()
embedding_loader = ChromaEmbeddingLoader()

model = ChromaModel(model_type=model_type)
self._load_internal_data(model, model_names.lora)
model.model_spec = self._load_default_model_spec(model_type)

if model_names.base_model is not None:
base_model_loader.load(model, model_type, model_names, weight_dtypes)
lora_model_loader.load(model, model_names)
embedding_loader.load(model, model_names.lora, model_names)

return model
ChromaLoRAModelLoader = make_lora_model_loader(
model_spec_map={ModelType.CHROMA_1: "resources/sd_model_spec/chroma-lora.json"},
model_class=ChromaModel,
model_loader_class=ChromaModelLoader,
embedding_loader_class=ChromaEmbeddingLoader,
lora_loader_class=ChromaLoRALoader,
)
54 changes: 10 additions & 44 deletions modules/modelLoader/FluxEmbeddingModelLoader.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,15 @@
from modules.model.FluxModel import FluxModel
from modules.modelLoader.BaseModelLoader import BaseModelLoader
from modules.modelLoader.flux.FluxEmbeddingLoader import FluxEmbeddingLoader
from modules.modelLoader.flux.FluxModelLoader import FluxModelLoader
from modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin
from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin
from modules.modelLoader.GenericEmbeddingModelLoader import make_embedding_model_loader
from modules.util.enum.ModelType import ModelType
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes


class FluxEmbeddingModelLoader(
BaseModelLoader,
ModelSpecModelLoaderMixin,
InternalModelLoaderMixin,
):
def __init__(self):
super().__init__()

def _default_model_spec_name(
self,
model_type: ModelType,
) -> str | None:
match model_type:
case ModelType.FLUX_DEV_1:
return "resources/sd_model_spec/flux_dev_1.0-embedding.json"
case ModelType.FLUX_FILL_DEV_1:
return "resources/sd_model_spec/flux_dev_fill_1.0-embedding.json"
case _:
return None

def load(
self,
model_type: ModelType,
model_names: ModelNames,
weight_dtypes: ModelWeightDtypes,
) -> FluxModel | None:
base_model_loader = FluxModelLoader()
embedding_loader = FluxEmbeddingLoader()

model = FluxModel(model_type=model_type)
self._load_internal_data(model, model_names.embedding.model_name)
model.model_spec = self._load_default_model_spec(model_type)

if model_names.base_model is not None:
base_model_loader.load(model, model_type, model_names, weight_dtypes)
embedding_loader.load(model, model_names.embedding.model_name, model_names)

return model
FluxEmbeddingModelLoader = make_embedding_model_loader(
model_spec_map={
ModelType.FLUX_DEV_1: "resources/sd_model_spec/flux_dev_1.0-embedding.json",
ModelType.FLUX_FILL_DEV_1: "resources/sd_model_spec/flux_dev_fill_1.0-embedding.json",
},
model_class=FluxModel,
model_loader_class=FluxModelLoader,
embedding_loader_class=FluxEmbeddingLoader,
)
54 changes: 10 additions & 44 deletions modules/modelLoader/FluxFineTuneModelLoader.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,15 @@
from modules.model.FluxModel import FluxModel
from modules.modelLoader.BaseModelLoader import BaseModelLoader
from modules.modelLoader.flux.FluxEmbeddingLoader import FluxEmbeddingLoader
from modules.modelLoader.flux.FluxModelLoader import FluxModelLoader
from modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin
from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin
from modules.modelLoader.GenericFineTuneModelLoader import make_fine_tune_model_loader
from modules.util.enum.ModelType import ModelType
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes


class FluxFineTuneModelLoader(
BaseModelLoader,
ModelSpecModelLoaderMixin,
InternalModelLoaderMixin,
):
def __init__(self):
super().__init__()

def _default_model_spec_name(
self,
model_type: ModelType,
) -> str | None:
match model_type:
case ModelType.FLUX_DEV_1:
return "resources/sd_model_spec/flux_dev_1.0.json"
case ModelType.FLUX_FILL_DEV_1:
return "resources/sd_model_spec/flux_dev_fill_1.0.json"
case _:
return None

def load(
self,
model_type: ModelType,
model_names: ModelNames,
weight_dtypes: ModelWeightDtypes,
) -> FluxModel | None:
base_model_loader = FluxModelLoader()
embedding_loader = FluxEmbeddingLoader()

model = FluxModel(model_type=model_type)

self._load_internal_data(model, model_names.base_model)
model.model_spec = self._load_default_model_spec(model_type)

base_model_loader.load(model, model_type, model_names, weight_dtypes)
embedding_loader.load(model, model_names.base_model, model_names)

return model
FluxFineTuneModelLoader = make_fine_tune_model_loader(
model_spec_map={
ModelType.FLUX_DEV_1: "resources/sd_model_spec/flux_dev_1.0.json",
ModelType.FLUX_FILL_DEV_1: "resources/sd_model_spec/flux_dev_fill_1.0.json",
},
model_class=FluxModel,
model_loader_class=FluxModelLoader,
embedding_loader_class=FluxEmbeddingLoader,
)
57 changes: 11 additions & 46 deletions modules/modelLoader/FluxLoRAModelLoader.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,17 @@
from modules.model.FluxModel import FluxModel
from modules.modelLoader.BaseModelLoader import BaseModelLoader
from modules.modelLoader.flux.FluxEmbeddingLoader import FluxEmbeddingLoader
from modules.modelLoader.flux.FluxLoRALoader import FluxLoRALoader
from modules.modelLoader.flux.FluxModelLoader import FluxModelLoader
from modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin
from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin
from modules.modelLoader.GenericLoRAModelLoader import make_lora_model_loader
from modules.util.enum.ModelType import ModelType
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes


class FluxLoRAModelLoader(
BaseModelLoader,
ModelSpecModelLoaderMixin,
InternalModelLoaderMixin,
):
def __init__(self):
super().__init__()

def _default_model_spec_name(
self,
model_type: ModelType,
) -> str | None:
match model_type:
case ModelType.FLUX_DEV_1:
return "resources/sd_model_spec/flux_dev_1.0-lora.json"
case ModelType.FLUX_FILL_DEV_1:
return "resources/sd_model_spec/flux_dev_fill_1.0-lora.json"
case _:
return None

def load(
self,
model_type: ModelType,
model_names: ModelNames,
weight_dtypes: ModelWeightDtypes,
) -> FluxModel | None:
base_model_loader = FluxModelLoader()
lora_model_loader = FluxLoRALoader()
embedding_loader = FluxEmbeddingLoader()

model = FluxModel(model_type=model_type)
self._load_internal_data(model, model_names.lora)
model.model_spec = self._load_default_model_spec(model_type)

if model_names.base_model is not None:
base_model_loader.load(model, model_type, model_names, weight_dtypes)
lora_model_loader.load(model, model_names)
embedding_loader.load(model, model_names.lora, model_names)

return model
FluxLoRAModelLoader = make_lora_model_loader(
model_spec_map={
ModelType.FLUX_DEV_1: "resources/sd_model_spec/flux_dev_1.0-lora.json",
ModelType.FLUX_FILL_DEV_1: "resources/sd_model_spec/flux_dev_fill_1.0-lora.json",
},
model_class=FluxModel,
model_loader_class=FluxModelLoader,
embedding_loader_class=FluxEmbeddingLoader,
lora_loader_class=FluxLoRALoader,
)
Loading