Skip to content
10 changes: 0 additions & 10 deletions modules/dataLoader/FluxBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,15 +201,6 @@ def _output_modules(self, config: TrainConfig, model: FluxModel):
if not config.train_text_encoder_2_or_embedding():
output_names.append('text_encoder_2_hidden_state')

sort_names = output_names + ['concept']
output_names = output_names + [('concept.loss_weight', 'loss_weight')]

# add for calculating loss per concept
if config.validation:
output_names.append(('concept.name', 'concept_name'))
output_names.append(('concept.path', 'concept_path'))
output_names.append(('concept.seed', 'concept_seed'))

def before_cache_image_fun():
model.to(self.temp_device)
model.vae_to(self.train_device)
Expand All @@ -218,7 +209,6 @@ def before_cache_image_fun():

return self._output_modules_from_out_names(
output_names=output_names,
sort_names=sort_names,
config=config,
before_cache_image_fun=before_cache_image_fun,
use_conditioning_image=True,
Expand Down
10 changes: 0 additions & 10 deletions modules/dataLoader/HunyuanVideoBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,15 +193,6 @@ def _output_modules(self, config: TrainConfig, model: HunyuanVideoModel):
if not config.train_text_encoder_2_or_embedding():
output_names.append('text_encoder_2_pooled_state')

sort_names = output_names + ['concept']
output_names = output_names + [('concept.loss_weight', 'loss_weight')]

# add for calculating loss per concept
if config.validation:
output_names.append(('concept.name', 'concept_name'))
output_names.append(('concept.path', 'concept_path'))
output_names.append(('concept.seed', 'concept_seed'))

def before_cache_image_fun():
model.to(self.temp_device)
model.vae_to(self.train_device)
Expand All @@ -210,7 +201,6 @@ def before_cache_image_fun():

return self._output_modules_from_out_names(
output_names=output_names,
sort_names=sort_names,
config=config,
before_cache_image_fun=before_cache_image_fun,
use_conditioning_image=True,
Expand Down
10 changes: 0 additions & 10 deletions modules/dataLoader/PixArtAlphaBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,15 +166,6 @@ def _output_modules(self, config: TrainConfig, model: PixArtAlphaModel):
if not config.train_text_encoder_or_embedding():
output_names.append('text_encoder_hidden_state')

sort_names = output_names + ['concept']
output_names = output_names + [('concept.loss_weight', 'loss_weight')]

# add for calculating loss per concept
if config.validation:
output_names.append(('concept.name', 'concept_name'))
output_names.append(('concept.path', 'concept_path'))
output_names.append(('concept.seed', 'concept_seed'))

def before_cache_image_fun():
model.to(self.temp_device)
model.vae_to(self.train_device)
Expand All @@ -183,7 +174,6 @@ def before_cache_image_fun():

return self._output_modules_from_out_names(
output_names=output_names,
sort_names=sort_names,
config=config,
before_cache_image_fun=before_cache_image_fun,
use_conditioning_image=True,
Expand Down
10 changes: 0 additions & 10 deletions modules/dataLoader/SanaBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,6 @@ def _output_modules(self, config: TrainConfig, model: SanaModel):
if not config.train_text_encoder_or_embedding():
output_names.append('text_encoder_hidden_state')

sort_names = output_names + ['concept']
output_names = output_names + [('concept.loss_weight', 'loss_weight')]

# add for calculating loss per concept
if config.validation:
output_names.append(('concept.name', 'concept_name'))
output_names.append(('concept.path', 'concept_path'))
output_names.append(('concept.seed', 'concept_seed'))

def before_cache_image_fun():
model.to(self.temp_device)
model.vae_to(self.train_device)
Expand All @@ -176,7 +167,6 @@ def before_cache_image_fun():

return self._output_modules_from_out_names(
output_names=output_names,
sort_names=sort_names,
config=config,
before_cache_image_fun=before_cache_image_fun,
use_conditioning_image=True,
Expand Down
10 changes: 0 additions & 10 deletions modules/dataLoader/StableDiffusion3BaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,15 +224,6 @@ def _output_modules(self, config: TrainConfig, model: StableDiffusion3Model):
if not config.train_text_encoder_3_or_embedding():
output_names.append('text_encoder_3_hidden_state')

sort_names = output_names + ['concept']
output_names = output_names + [('concept.loss_weight', 'loss_weight')]

# add for calculating loss per concept
if config.validation:
output_names.append(('concept.name', 'concept_name'))
output_names.append(('concept.path', 'concept_path'))
output_names.append(('concept.seed', 'concept_seed'))

def before_cache_image_fun():
model.to(self.temp_device)
model.vae_to(self.train_device)
Expand All @@ -241,7 +232,6 @@ def before_cache_image_fun():

return self._output_modules_from_out_names(
output_names=output_names,
sort_names=sort_names,
config=config,
before_cache_image_fun=before_cache_image_fun,
use_conditioning_image=True,
Expand Down
10 changes: 0 additions & 10 deletions modules/dataLoader/StableDiffusionBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,6 @@ def _output_modules(self, config: TrainConfig, model: StableDiffusionModel):
if not config.train_text_encoder_or_embedding():
output_names.append('text_encoder_hidden_state')

sort_names = output_names + ['concept']
output_names = output_names + [('concept.loss_weight', 'loss_weight')]

# add for calculating loss per concept
if config.validation:
output_names.append(('concept.name', 'concept_name'))
output_names.append(('concept.path', 'concept_path'))
output_names.append(('concept.seed', 'concept_seed'))

def before_cache_image_fun():
model.to(self.temp_device)
model.vae_to(self.train_device)
Expand All @@ -187,7 +178,6 @@ def before_cache_image_fun():

return self._output_modules_from_out_names(
output_names=output_names,
sort_names=sort_names,
config=config,
before_cache_image_fun=before_cache_image_fun,
use_conditioning_image=True,
Expand Down
1 change: 1 addition & 0 deletions modules/dataLoader/StableDiffusionFineTuneVaeDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def __output_modules(self, config: TrainConfig):

sort_names = output_names + ['concept']
output_names = output_names + [('concept.loss_weight', 'loss_weight')]
output_names = output_names + [('concept.type', 'concept_type')]

# add for calculating loss per concept
if config.validation:
Expand Down
10 changes: 0 additions & 10 deletions modules/dataLoader/StableDiffusionXLBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,6 @@ def _output_modules(self, config: TrainConfig, model: StableDiffusionXLModel):
output_names.append('text_encoder_2_hidden_state')
output_names.append('text_encoder_2_pooled_state')

sort_names = output_names + ['concept']
output_names = output_names + [('concept.loss_weight', 'loss_weight')]

# add for calculating loss per concept
if config.validation:
output_names.append(('concept.name', 'concept_name'))
output_names.append(('concept.path', 'concept_path'))
output_names.append(('concept.seed', 'concept_seed'))

def before_cache_image_fun():
model.to(self.temp_device)
model.vae_to(self.train_device)
Expand All @@ -209,7 +200,6 @@ def before_cache_image_fun():

return self._output_modules_from_out_names(
output_names=output_names,
sort_names=sort_names,
config=config,
before_cache_image_fun=before_cache_image_fun,
use_conditioning_image=True,
Expand Down
10 changes: 0 additions & 10 deletions modules/dataLoader/WuerstchenBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,6 @@ def _output_modules(self, config: TrainConfig, model: WuerstchenModel):
if model.model_type.is_stable_cascade():
output_names.append('pooled_text_encoder_output')

sort_names = output_names + ['concept']
output_names = output_names + [('concept.loss_weight', 'loss_weight')]

# add for calculating loss per concept
if config.validation:
output_names.append(('concept.name', 'concept_name'))
output_names.append(('concept.path', 'concept_path'))
output_names.append(('concept.seed', 'concept_seed'))

def before_cache_image_fun():
model.to(self.temp_device)
model.effnet_encoder_to(self.train_device)
Expand All @@ -177,7 +168,6 @@ def before_cache_image_fun():

return self._output_modules_from_out_names(
output_names=output_names,
sort_names=sort_names,
config=config,
before_cache_image_fun=before_cache_image_fun,
autocast_context=[model.autocast_context],
Expand Down
3 changes: 2 additions & 1 deletion modules/dataLoader/mixin/DataLoaderMgdsMixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from modules.util.config.ConceptConfig import ConceptConfig
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ConceptType import ConceptType
from modules.util.TrainProgress import TrainProgress

from mgds.MGDS import MGDS
Expand All @@ -26,7 +27,7 @@ def _create_mgds(
concepts = [ConceptConfig.default_values().from_dict(c) for c in json.load(f)]

# choose all validation concepts, or none of them, depending on is_validation
concepts = [concept for concept in concepts if concept.validation_concept == is_validation]
concepts = [concept for concept in concepts if (ConceptType(concept.type) == ConceptType.VALIDATION) == is_validation]

# convert before passing to MGDS
concepts = [c.to_dict() for c in concepts]
Expand Down
15 changes: 13 additions & 2 deletions modules/dataLoader/mixin/DataLoaderText2ImageMixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,15 +225,26 @@ def _inpainting_modules(self, config: TrainConfig):

def _output_modules_from_out_names(
self,
output_names: list[str],
sort_names: list[str],
output_names: list[str | tuple[str, str]],
config: TrainConfig,
before_cache_image_fun: Callable[[], None] | None = None,
use_conditioning_image: bool = False,
vae: AutoencoderKL | None = None,
autocast_context: list[torch.autocast | None] = None,
train_dtype: DataType | None = None,
):
sort_names = output_names + ['concept']

output_names = output_names + [
('concept.loss_weight', 'loss_weight'),
('concept.type', 'concept_type'),
]

if config.validation:
output_names.append(('concept.name', 'concept_name'))
output_names.append(('concept.path', 'concept_path'))
output_names.append(('concept.seed', 'concept_seed'))

mask_remove = RandomLatentMaskRemove(
latent_mask_name='latent_mask', latent_conditioning_image_name='latent_conditioning_image' if use_conditioning_image else None,
replace_probability=config.unmasked_probability, vae=vae,
Expand Down
5 changes: 5 additions & 0 deletions modules/model/BaseModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from uuid import uuid4

from modules.module.EMAModule import EMAModuleWrapper
from modules.module.LoRAModule import LoRAModuleWrapper
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.DataType import DataType
from modules.util.enum.ModelType import ModelType
Expand Down Expand Up @@ -101,6 +102,10 @@ def to(self, device: torch.device):
def eval(self):
pass

@abstractmethod
def adapters(self) -> list[LoRAModuleWrapper]:
pass

@staticmethod
def _add_embeddings_to_prompt(
additional_embeddings: list[BaseModelEmbedding],
Expand Down
7 changes: 7 additions & 0 deletions modules/model/FluxModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ def __init__(
self.transformer_lora = None
self.lora_state_dict = None

def adapters(self) -> list[LoRAModuleWrapper]:
return [a for a in [
self.text_encoder_1_lora,
self.text_encoder_2_lora,
self.transformer_lora,
] if a is not None]

def all_embeddings(self) -> list[FluxModelEmbedding]:
return self.additional_embeddings \
+ ([self.embedding] if self.embedding is not None else [])
Expand Down
7 changes: 7 additions & 0 deletions modules/model/HunyuanVideoModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,13 @@ def __init__(
self.transformer_lora = None
self.lora_state_dict = None

def adapters(self) -> list[LoRAModuleWrapper]:
return [a for a in [
self.text_encoder_1_lora,
self.text_encoder_2_lora,
self.transformer_lora,
] if a is not None]

def all_embeddings(self) -> list[HunyuanVideoModelEmbedding]:
return self.additional_embeddings \
+ ([self.embedding] if self.embedding is not None else [])
Expand Down
6 changes: 6 additions & 0 deletions modules/model/PixArtAlphaModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ def __init__(
self.transformer_lora = None
self.lora_state_dict = None

def adapters(self) -> list[LoRAModuleWrapper]:
return [a for a in [
self.text_encoder_lora,
self.transformer_lora,
] if a is not None]

def all_embeddings(self) -> list[PixArtAlphaModelEmbedding]:
return self.additional_embeddings \
+ ([self.embedding] if self.embedding is not None else [])
Expand Down
6 changes: 6 additions & 0 deletions modules/model/SanaModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ def __init__(
self.transformer_lora = None
self.lora_state_dict = None

def adapters(self) -> list[LoRAModuleWrapper]:
return [a for a in [
self.text_encoder_lora,
self.transformer_lora,
] if a is not None]

def all_embeddings(self) -> list[SanaModelEmbedding]:
return self.additional_embeddings \
+ ([self.embedding] if self.embedding is not None else [])
Expand Down
8 changes: 8 additions & 0 deletions modules/model/StableDiffusion3Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,14 @@ def __init__(
self.transformer_lora = None
self.lora_state_dict = None

def adapters(self) -> list[LoRAModuleWrapper]:
return [a for a in [
self.text_encoder_1_lora,
self.text_encoder_2_lora,
self.text_encoder_3_lora,
self.transformer_lora,
] if a is not None]

def all_embeddings(self) -> list[StableDiffusion3ModelEmbedding]:
return self.additional_embeddings \
+ ([self.embedding] if self.embedding is not None else [])
Expand Down
6 changes: 6 additions & 0 deletions modules/model/StableDiffusionModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ def __init__(
self.sd_config = None
self.sd_config_filename = None

def adapters(self) -> list[LoRAModuleWrapper]:
return [a for a in [
self.text_encoder_lora,
self.unet_lora,
] if a is not None]

def all_embeddings(self) -> list[StableDiffusionModelEmbedding]:
return self.additional_embeddings \
+ ([self.embedding] if self.embedding is not None else [])
Expand Down
7 changes: 7 additions & 0 deletions modules/model/StableDiffusionXLModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ def __init__(
self.sd_config = None
self.sd_config_filename = None

def adapters(self) -> list[LoRAModuleWrapper]:
return [a for a in [
self.text_encoder_1_lora,
self.text_encoder_2_lora,
self.unet_lora,
] if a is not None]

def all_embeddings(self) -> list[StableDiffusionXLModelEmbedding]:
return self.additional_embeddings \
+ ([self.embedding] if self.embedding is not None else [])
Expand Down
6 changes: 6 additions & 0 deletions modules/model/WuerstchenModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ def __init__(
self.prior_prior_lora = None
self.lora_state_dict = None

def adapters(self) -> list[LoRAModuleWrapper]:
return [a for a in [
self.prior_text_encoder_lora,
self.prior_prior_lora,
] if a is not None]

def all_embeddings(self) -> list[WuerstchenModelEmbedding]:
return self.additional_embeddings \
+ ([self.embedding] if self.embedding is not None else [])
Expand Down
Loading