Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions modules/dataLoader/ChromaBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from modules.dataLoader.BaseDataLoader import BaseDataLoader
from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin
from modules.model.ChromaModel import ChromaModel
from modules.util import factory
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import ModelType
from modules.util.torch_util import torch_gc
from modules.util.TrainProgress import TrainProgress

Expand Down Expand Up @@ -251,3 +253,5 @@ def create_dataset(
train_progress,
is_validation
)

factory.register(BaseDataLoader, ChromaBaseDataLoader, ModelType.CHROMA_1)
5 changes: 5 additions & 0 deletions modules/dataLoader/FluxBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from modules.dataLoader.flux.ShuffleFluxFillMaskChannels import ShuffleFluxFillMaskChannels
from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin
from modules.model.FluxModel import FluxModel
from modules.util import factory
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import ModelType
from modules.util.torch_util import torch_gc
from modules.util.TrainProgress import TrainProgress

Expand Down Expand Up @@ -296,3 +298,6 @@ def create_dataset(
train_progress,
is_validation
)

factory.register(BaseDataLoader, FluxBaseDataLoader, ModelType.FLUX_DEV_1)
factory.register(BaseDataLoader, FluxBaseDataLoader, ModelType.FLUX_FILL_DEV_1)
4 changes: 4 additions & 0 deletions modules/dataLoader/HiDreamBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from modules.dataLoader.flux.ShuffleFluxFillMaskChannels import ShuffleFluxFillMaskChannels
from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin
from modules.model.HiDreamModel import HiDreamModel
from modules.util import factory
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import ModelType
from modules.util.torch_util import torch_gc
from modules.util.TrainProgress import TrainProgress

Expand Down Expand Up @@ -341,3 +343,5 @@ def create_dataset(
train_progress,
is_validation
)

factory.register(BaseDataLoader, HiDreamBaseDataLoader, ModelType.HI_DREAM_FULL)
4 changes: 4 additions & 0 deletions modules/dataLoader/HunyuanVideoBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
DEFAULT_PROMPT_TEMPLATE_CROP_START,
HunyuanVideoModel,
)
from modules.util import factory
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import ModelType
from modules.util.torch_util import torch_gc
from modules.util.TrainProgress import TrainProgress

Expand Down Expand Up @@ -291,3 +293,5 @@ def create_dataset(
train_progress,
is_validation
)

factory.register(BaseDataLoader, HunyuanVideoBaseDataLoader, ModelType.HUNYUAN_VIDEO)
5 changes: 5 additions & 0 deletions modules/dataLoader/PixArtAlphaBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from modules.dataLoader.BaseDataLoader import BaseDataLoader
from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin
from modules.model.PixArtAlphaModel import PixArtAlphaModel
from modules.util import factory
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import ModelType
from modules.util.torch_util import torch_gc
from modules.util.TrainProgress import TrainProgress

Expand Down Expand Up @@ -261,3 +263,6 @@ def create_dataset(
train_progress,
is_validation,
)

factory.register(BaseDataLoader, PixArtAlphaBaseDataLoader, ModelType.PIXART_ALPHA)
factory.register(BaseDataLoader, PixArtAlphaBaseDataLoader, ModelType.PIXART_SIGMA)
4 changes: 4 additions & 0 deletions modules/dataLoader/QwenBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
PROMPT_MAX_LENGTH,
QwenModel,
)
from modules.util import factory
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import ModelType
from modules.util.torch_util import torch_gc
from modules.util.TrainProgress import TrainProgress

Expand Down Expand Up @@ -258,3 +260,5 @@ def create_dataset(
train_progress,
is_validation
)

factory.register(BaseDataLoader, QwenBaseDataLoader, ModelType.QWEN)
4 changes: 4 additions & 0 deletions modules/dataLoader/SanaBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from modules.dataLoader.BaseDataLoader import BaseDataLoader
from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin
from modules.model.SanaModel import SanaModel
from modules.util import factory
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import ModelType
from modules.util.torch_util import torch_gc
from modules.util.TrainProgress import TrainProgress

Expand Down Expand Up @@ -254,3 +256,5 @@ def create_dataset(
train_progress,
is_validation,
)

factory.register(BaseDataLoader, SanaBaseDataLoader, ModelType.SANA)
4 changes: 4 additions & 0 deletions modules/dataLoader/StableDiffusion3BaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from modules.dataLoader.BaseDataLoader import BaseDataLoader
from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin
from modules.model.StableDiffusion3Model import StableDiffusion3Model
from modules.util import factory
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import ModelType
from modules.util.torch_util import torch_gc
from modules.util.TrainProgress import TrainProgress

Expand Down Expand Up @@ -319,3 +321,5 @@ def create_dataset(
train_progress,
is_validation,
)

factory.register(BaseDataLoader, StableDiffusion3BaseDataLoader, ModelType.STABLE_DIFFUSION_35)
11 changes: 11 additions & 0 deletions modules/dataLoader/StableDiffusionBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from modules.dataLoader.BaseDataLoader import BaseDataLoader
from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin
from modules.model.StableDiffusionModel import StableDiffusionModel
from modules.util import factory
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import ModelType
from modules.util.torch_util import torch_gc
from modules.util.TrainProgress import TrainProgress

Expand Down Expand Up @@ -267,3 +269,12 @@ def create_dataset(
train_progress,
is_validation,
)

factory.register(BaseDataLoader, StableDiffusionBaseDataLoader, ModelType.STABLE_DIFFUSION_15)
factory.register(BaseDataLoader, StableDiffusionBaseDataLoader, ModelType.STABLE_DIFFUSION_15_INPAINTING)
factory.register(BaseDataLoader, StableDiffusionBaseDataLoader, ModelType.STABLE_DIFFUSION_20)
factory.register(BaseDataLoader, StableDiffusionBaseDataLoader, ModelType.STABLE_DIFFUSION_20_BASE)
factory.register(BaseDataLoader, StableDiffusionBaseDataLoader, ModelType.STABLE_DIFFUSION_20_INPAINTING)
factory.register(BaseDataLoader, StableDiffusionBaseDataLoader, ModelType.STABLE_DIFFUSION_20_DEPTH)
factory.register(BaseDataLoader, StableDiffusionBaseDataLoader, ModelType.STABLE_DIFFUSION_21)
factory.register(BaseDataLoader, StableDiffusionBaseDataLoader, ModelType.STABLE_DIFFUSION_21_BASE)
13 changes: 12 additions & 1 deletion modules/dataLoader/StableDiffusionFineTuneVaeDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

from modules.dataLoader.BaseDataLoader import BaseDataLoader
from modules.model.StableDiffusionModel import StableDiffusionModel
from modules.util import path_util
from modules.util import factory, path_util
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import ModelType
from modules.util.enum.TrainingMethod import TrainingMethod
from modules.util.torch_util import torch_gc
from modules.util.TrainProgress import TrainProgress

Expand Down Expand Up @@ -326,3 +328,12 @@ def create_dataset(
train_progress,
is_validation,
)

factory.register(BaseDataLoader, StableDiffusionFineTuneVaeDataLoader, ModelType.STABLE_DIFFUSION_15, TrainingMethod.FINE_TUNE_VAE)
factory.register(BaseDataLoader, StableDiffusionFineTuneVaeDataLoader, ModelType.STABLE_DIFFUSION_15_INPAINTING, TrainingMethod.FINE_TUNE_VAE)
factory.register(BaseDataLoader, StableDiffusionFineTuneVaeDataLoader, ModelType.STABLE_DIFFUSION_20, TrainingMethod.FINE_TUNE_VAE)
factory.register(BaseDataLoader, StableDiffusionFineTuneVaeDataLoader, ModelType.STABLE_DIFFUSION_20_BASE, TrainingMethod.FINE_TUNE_VAE)
factory.register(BaseDataLoader, StableDiffusionFineTuneVaeDataLoader, ModelType.STABLE_DIFFUSION_20_INPAINTING, TrainingMethod.FINE_TUNE_VAE)
factory.register(BaseDataLoader, StableDiffusionFineTuneVaeDataLoader, ModelType.STABLE_DIFFUSION_20_DEPTH, TrainingMethod.FINE_TUNE_VAE)
factory.register(BaseDataLoader, StableDiffusionFineTuneVaeDataLoader, ModelType.STABLE_DIFFUSION_21, TrainingMethod.FINE_TUNE_VAE)
factory.register(BaseDataLoader, StableDiffusionFineTuneVaeDataLoader, ModelType.STABLE_DIFFUSION_21_BASE, TrainingMethod.FINE_TUNE_VAE)
4 changes: 4 additions & 0 deletions modules/dataLoader/StableDiffusionXLBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from modules.dataLoader.BaseDataLoader import BaseDataLoader
from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin
from modules.model.StableDiffusionXLModel import StableDiffusionXLModel
from modules.util import factory
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import ModelType
from modules.util.torch_util import torch_gc
from modules.util.TrainProgress import TrainProgress

Expand Down Expand Up @@ -287,3 +289,5 @@ def create_dataset(
train_progress,
is_validation,
)
factory.register(BaseDataLoader, StableDiffusionXLBaseDataLoader, ModelType.STABLE_DIFFUSION_XL_10_BASE)
factory.register(BaseDataLoader, StableDiffusionXLBaseDataLoader, ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING)
5 changes: 5 additions & 0 deletions modules/dataLoader/WuerstchenBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin
from modules.dataLoader.wuerstchen.EncodeWuerstchenEffnet import EncodeWuerstchenEffnet
from modules.model.WuerstchenModel import WuerstchenModel
from modules.util import factory
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import ModelType
from modules.util.torch_util import torch_gc
from modules.util.TrainProgress import TrainProgress

Expand Down Expand Up @@ -237,3 +239,6 @@ def create_dataset(
train_progress,
is_validation,
)

factory.register(BaseDataLoader, WuerstchenBaseDataLoader, ModelType.WUERSTCHEN_2)
factory.register(BaseDataLoader, WuerstchenBaseDataLoader, ModelType.STABLE_CASCADE_1)
4 changes: 4 additions & 0 deletions modules/dataLoader/ZImageBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from modules.dataLoader.BaseDataLoader import BaseDataLoader
from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin
from modules.model.ZImageModel import PROMPT_MAX_LENGTH, ZImageModel, format_input
from modules.util import factory
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import ModelType
from modules.util.torch_util import torch_gc
from modules.util.TrainProgress import TrainProgress

Expand Down Expand Up @@ -242,3 +244,5 @@ def create_dataset(
train_progress,
is_validation
)

factory.register(BaseDataLoader, ZImageBaseDataLoader, ModelType.Z_IMAGE)
5 changes: 5 additions & 0 deletions modules/modelLoader/GenericEmbeddingModelLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from modules.modelLoader.BaseModelLoader import BaseModelLoader
from modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin
from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin
from modules.util import factory
from modules.util.config.TrainConfig import QuantizationConfig
from modules.util.enum.ModelType import ModelType
from modules.util.enum.TrainingMethod import TrainingMethod
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes

Expand Down Expand Up @@ -47,4 +49,7 @@ def load(
embedding_loader.load(model, model_names.embedding.model_name, model_names)

return model

for model_type in model_spec_map:
factory.register(BaseModelLoader, GenericEmbeddingModelLoader, model_type, TrainingMethod.EMBEDDING)
return GenericEmbeddingModelLoader
9 changes: 9 additions & 0 deletions modules/modelLoader/GenericFineTuneModelLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from modules.modelLoader.BaseModelLoader import BaseModelLoader
from modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin
from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin
from modules.util import factory
from modules.util.config.TrainConfig import QuantizationConfig
from modules.util.enum.ModelType import ModelType
from modules.util.enum.TrainingMethod import TrainingMethod
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes

Expand All @@ -13,7 +15,11 @@ def make_fine_tune_model_loader(
model_class: type[BaseModel],
model_loader_class: type,
embedding_loader_class: type | None,
training_methods: list[TrainingMethod] = None,
):
if training_methods is None:
training_methods = [TrainingMethod.FINE_TUNE]

class GenericFineTuneModelLoader(
BaseModelLoader,
ModelSpecModelLoaderMixin,
Expand Down Expand Up @@ -50,4 +56,7 @@ def load(

return model

for model_type in model_spec_map:
for method in training_methods:
factory.register(BaseModelLoader, GenericFineTuneModelLoader, model_type, method)
return GenericFineTuneModelLoader
4 changes: 4 additions & 0 deletions modules/modelLoader/GenericLoRAModelLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from modules.modelLoader.BaseModelLoader import BaseModelLoader
from modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin
from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin
from modules.util import factory
from modules.util.config.TrainConfig import QuantizationConfig
from modules.util.enum.ModelType import ModelType
from modules.util.enum.TrainingMethod import TrainingMethod
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes

Expand Down Expand Up @@ -53,4 +55,6 @@ def load(

return model

for model_type in model_spec_map:
factory.register(BaseModelLoader, GenericLoRAModelLoader, model_type, TrainingMethod.LORA)
return GenericLoRAModelLoader
2 changes: 2 additions & 0 deletions modules/modelLoader/StableDiffusionFineTuneModelLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from modules.modelLoader.stableDiffusion.StableDiffusionEmbeddingLoader import StableDiffusionEmbeddingLoader
from modules.modelLoader.stableDiffusion.StableDiffusionModelLoader import StableDiffusionModelLoader
from modules.util.enum.ModelType import ModelType
from modules.util.enum.TrainingMethod import TrainingMethod

StableDiffusionFineTuneModelLoader = make_fine_tune_model_loader(
model_spec_map={
Expand All @@ -18,4 +19,5 @@
model_class=StableDiffusionModel,
model_loader_class=StableDiffusionModelLoader,
embedding_loader_class=StableDiffusionEmbeddingLoader,
training_methods=[TrainingMethod.FINE_TUNE, TrainingMethod.FINE_TUNE_VAE],
)
3 changes: 3 additions & 0 deletions modules/modelSampler/ChromaSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from modules.model.ChromaModel import ChromaModel
from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput
from modules.util import factory
from modules.util.config.SampleConfig import SampleConfig
from modules.util.enum.AudioFormat import AudioFormat
from modules.util.enum.FileType import FileType
Expand Down Expand Up @@ -188,3 +189,5 @@ def sample(
)

on_sample(sampler_output)

factory.register(BaseModelSampler, ChromaSampler, ModelType.CHROMA_1)
4 changes: 4 additions & 0 deletions modules/modelSampler/FluxSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from modules.model.FluxModel import FluxModel
from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput
from modules.util import factory
from modules.util.config.SampleConfig import SampleConfig
from modules.util.enum.AudioFormat import AudioFormat
from modules.util.enum.FileType import FileType
Expand Down Expand Up @@ -450,3 +451,6 @@ def sample(
)

on_sample(sampler_output)

factory.register(BaseModelSampler, FluxSampler, ModelType.FLUX_DEV_1)
factory.register(BaseModelSampler, FluxSampler, ModelType.FLUX_FILL_DEV_1)
3 changes: 3 additions & 0 deletions modules/modelSampler/HiDreamSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from modules.model.HiDreamModel import HiDreamModel
from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput
from modules.util import factory
from modules.util.config.SampleConfig import SampleConfig
from modules.util.enum.AudioFormat import AudioFormat
from modules.util.enum.FileType import FileType
Expand Down Expand Up @@ -191,3 +192,5 @@ def sample(
)

on_sample(sampler_output)

factory.register(BaseModelSampler, HiDreamSampler, ModelType.HI_DREAM_FULL)
3 changes: 3 additions & 0 deletions modules/modelSampler/HunyuanVideoSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from modules.model.HunyuanVideoModel import HunyuanVideoModel
from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput
from modules.util import factory
from modules.util.config.SampleConfig import SampleConfig
from modules.util.enum.AudioFormat import AudioFormat
from modules.util.enum.FileType import FileType
Expand Down Expand Up @@ -205,3 +206,5 @@ def sample(
)

on_sample(sampler_output)

factory.register(BaseModelSampler, HunyuanVideoSampler, ModelType.HUNYUAN_VIDEO)
5 changes: 4 additions & 1 deletion modules/modelSampler/PixArtAlphaSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from modules.model.PixArtAlphaModel import PixArtAlphaModel
from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput
from modules.util import create
from modules.util import create, factory
from modules.util.config.SampleConfig import SampleConfig
from modules.util.enum.AudioFormat import AudioFormat
from modules.util.enum.FileType import FileType
Expand Down Expand Up @@ -191,3 +191,6 @@ def sample(
)

on_sample(sampler_output)

factory.register(BaseModelSampler, PixArtAlphaSampler, ModelType.PIXART_ALPHA)
factory.register(BaseModelSampler, PixArtAlphaSampler, ModelType.PIXART_SIGMA)
3 changes: 3 additions & 0 deletions modules/modelSampler/QwenSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from modules.model.QwenModel import QwenModel
from modules.modelSampler.BaseModelSampler import BaseModelSampler, ModelSamplerOutput
from modules.util import factory
from modules.util.config.SampleConfig import SampleConfig
from modules.util.enum.AudioFormat import AudioFormat
from modules.util.enum.FileType import FileType
Expand Down Expand Up @@ -198,3 +199,5 @@ def sample(
)

on_sample(sampler_output)

factory.register(BaseModelSampler, QwenSampler, ModelType.QWEN)
Loading