Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
40a14df
generic model saver
dxqb Dec 17, 2025
7a8df73
factory for model components
dxqb Dec 17, 2025
c85e20e
merge
dxqb Dec 17, 2025
f8d55f3
remove unused code
dxqb Dec 17, 2025
1fef80b
simplify debug code
dxqb Dec 17, 2025
f1942b2
type hint
dxqb Dec 17, 2025
900c284
missing file
dxqb Dec 19, 2025
9670bf4
refactor dataloaders
dxqb Dec 23, 2025
7d44d1c
undo unnecessary code change
dxqb Dec 23, 2025
3de1116
undo unnecessary code change
dxqb Dec 23, 2025
3fba6fb
revert prepare vae
dxqb Dec 23, 2025
27d57e6
more models
dxqb Dec 23, 2025
9ce5de4
convert util
dxqb Dec 30, 2025
faae3fc
dependency
dxqb Dec 30, 2025
68b7851
remove exception
dxqb Dec 30, 2025
e5a55b2
Flux
dxqb Dec 30, 2025
73b2712
Merge branch 'pr-1210' into flux2_base
dxqb Dec 30, 2025
0483c6f
Merge branch 'pr-1211' into flux2_base
dxqb Dec 30, 2025
a97da04
Merge branch 'pr-1212' into flux2_base
dxqb Dec 30, 2025
61ec0d8
merge
dxqb Dec 30, 2025
a7807d7
Merge branch 'pr-1236' into flux2_base
dxqb Dec 30, 2025
e33bc47
merge
dxqb Dec 30, 2025
17056c5
mgds dependency
dxqb Dec 30, 2025
e50970f
Flux.Klein support
dxqb Jan 16, 2026
7c23519
mgds dependency
dxqb Jan 16, 2026
bada204
Merge branch 'upstream' into flux2_klein
dxqb Jan 16, 2026
2870d01
Merge branch 'upstream' into flux2
dxqb Jan 16, 2026
b0bd028
Merge branch 'upstream' into flux2
dxqb Jan 16, 2026
210a9e4
Merge branch 'upstream' into flux2_klein
dxqb Jan 16, 2026
f1f9320
fix arguments to validation data loader
dxqb Jan 17, 2026
1ebe8c9
Merge branch 'dataloader' into flux2_base
dxqb Jan 17, 2026
85eab22
Merge branch 'flux2_base' into flux2
dxqb Jan 17, 2026
de9f3be
Merge branch 'flux2' into flux2_klein
dxqb Jan 17, 2026
9f82679
Qwen3 tied weights workaround
dxqb Jan 17, 2026
a3be5ee
unpatchify, to match the shape of masks
dxqb Jan 19, 2026
e94354a
Merge branch 'flux2' into flux2_klein
dxqb Jan 19, 2026
792510a
rename Comfy and remove filter, because UI is not updated when you ch…
dxqb Jan 19, 2026
f7c4d8f
Merge branch 'flux2' into flux2_klein
dxqb Jan 19, 2026
4c0de7b
disable Comfy LoRA format, change prefix
dxqb Jan 24, 2026
f995e54
Merge branch 'merge' into flux2_klein
dxqb Jan 24, 2026
7beeab6
Fix Readme typo
O-J1 Jan 28, 2026
ecc4da2
Fix syntax error in preset
dxqb Jan 29, 2026
922a4f6
workaround for GGUF loading
Jan 29, 2026
01836db
upgrade diffusers
Jan 30, 2026
5d3bd18
merge
Feb 1, 2026
bdc10ca
pre-commit
Feb 1, 2026
a8640ed
disable embedding training for flux2
Feb 1, 2026
ae54674
dynamic timestep shifting
Feb 1, 2026
79f4405
remove comment
Feb 1, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ OneTrainer is a one-stop solution for all your Diffusion training needs.
> sudo apt-get install libgl1
> ```
>
> Additionally it's been reported Alpine, Arch and Xubuntuu Linux may be missing `tkinter`. Install it via `apk add py3-tk` for Alpine and `sudo pacman -S tk` for Arch.
> Additionally it's been reported Alpine, Arch and Xubuntu Linux may be missing `tkinter`. Install it via `apk add py3-tk` for Alpine and `sudo pacman -S tk` for Arch.
## Updating
Expand Down
164 changes: 164 additions & 0 deletions modules/dataLoader/Flux2BaseDataLoader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import os

from modules.dataLoader.BaseDataLoader import BaseDataLoader
from modules.dataLoader.mixin.DataLoaderText2ImageMixin import DataLoaderText2ImageMixin
from modules.model.Flux2Model import (
MISTRAL_HIDDEN_STATES_LAYERS,
MISTRAL_SYSTEM_MESSAGE,
QWEN3_HIDDEN_STATES_LAYERS,
Flux2Model,
mistral_format_input,
qwen3_format_input,
)
from modules.modelSetup.BaseFlux2Setup import BaseFlux2Setup
from modules.util import factory
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import ModelType
from modules.util.TrainProgress import TrainProgress

from mgds.pipelineModules.DecodeTokens import DecodeTokens
from mgds.pipelineModules.DecodeVAE import DecodeVAE
from mgds.pipelineModules.EncodeMistralText import EncodeMistralText
from mgds.pipelineModules.EncodeQwenText import EncodeQwenText
from mgds.pipelineModules.EncodeVAE import EncodeVAE
from mgds.pipelineModules.RescaleImageChannels import RescaleImageChannels
from mgds.pipelineModules.SampleVAEDistribution import SampleVAEDistribution
from mgds.pipelineModules.SaveImage import SaveImage
from mgds.pipelineModules.SaveText import SaveText
from mgds.pipelineModules.ScaleImage import ScaleImage
from mgds.pipelineModules.Tokenize import Tokenize


class Flux2BaseDataLoader(
    BaseDataLoader,
    DataLoaderText2ImageMixin,
):
    """Data loader for Flux2 text-to-image training (Dev and Klein variants).

    Assembles the mgds pipeline stages: preparation (VAE and text encoding),
    disk caching, trainer output collection, and optional debug dumps.
    """

    def _preparation_modules(self, config: TrainConfig, model: Flux2Model):
        """Build the modules that turn raw samples into latents and embeddings.

        Images are rescaled from [0, 1] to [-1, 1], VAE-encoded, and sampled
        (mean of the latent distribution). Prompts are tokenized with the
        variant-specific chat template and encoded by the matching text
        encoder: Mistral for Flux2 Dev, Qwen3 for Flux2 Klein.
        """
        rescale_image = RescaleImageChannels(image_in_name='image', image_out_name='image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1)
        encode_image = EncodeVAE(in_name='image', out_name='latent_image_distribution', vae=model.vae, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype())
        image_sample = SampleVAEDistribution(in_name='latent_image_distribution', out_name='latent_image', mode='mean')
        # factor=0.125 downscales the mask by 8x to match the latent resolution
        downscale_mask = ScaleImage(in_name='mask', out_name='latent_mask', factor=0.125)

        if model.is_dev():
            tokenize_prompt = Tokenize(
                in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask',
                tokenizer=model.tokenizer, max_token_length=config.text_encoder_sequence_length,
                apply_chat_template=lambda caption: mistral_format_input([caption], MISTRAL_SYSTEM_MESSAGE),
                apply_chat_template_kwargs={'add_generation_prompt': False},
            )
            encode_prompt = EncodeMistralText(
                tokens_name='tokens', tokens_attention_mask_in_name='tokens_mask',
                hidden_state_out_name='text_encoder_hidden_state', tokens_attention_mask_out_name='tokens_mask',
                text_encoder=model.text_encoder, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype(),
                hidden_state_output_index=MISTRAL_HIDDEN_STATES_LAYERS,
            )
        else:  # Klein
            # Fail fast, before constructing any pipeline modules.
            # TODO this code is copied from Z-Image, which also uses Qwen3ForCausalLM. The leak issue probably also applies for Flux2.Klein:
            if config.dataloader_threads > 1:
                raise NotImplementedError("Multiple data loader threads are not supported due to an issue with the transformers library: https://github.com/huggingface/transformers/issues/42673")
            tokenize_prompt = Tokenize(
                in_name='prompt', tokens_out_name='tokens', mask_out_name='tokens_mask',
                tokenizer=model.tokenizer, max_token_length=config.text_encoder_sequence_length,
                apply_chat_template=qwen3_format_input,
                apply_chat_template_kwargs={'add_generation_prompt': True, 'enable_thinking': False},
            )
            encode_prompt = EncodeQwenText(
                tokens_name='tokens', tokens_attention_mask_in_name='tokens_mask',
                hidden_state_out_name='text_encoder_hidden_state', tokens_attention_mask_out_name='tokens_mask',
                text_encoder=model.text_encoder, hidden_state_output_index=QWEN3_HIDDEN_STATES_LAYERS,
                autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype(),
            )

        modules = [rescale_image, encode_image, image_sample]
        if config.masked_training or config.model_type.has_mask_input():
            modules.append(downscale_mask)
        modules += [tokenize_prompt, encode_prompt]
        return modules

    def _cache_modules(self, config: TrainConfig, model: Flux2Model, model_setup: BaseFlux2Setup):
        """Declare which pipeline outputs are cached, split vs. aggregated per item."""
        image_split_names = ['latent_image', 'original_resolution', 'crop_offset']
        if config.masked_training or config.model_type.has_mask_input():
            image_split_names.append('latent_mask')

        image_aggregate_names = ['crop_resolution', 'image_path']
        text_split_names = ['tokens', 'tokens_mask', 'text_encoder_hidden_state']

        sort_names = image_aggregate_names + image_split_names + [
            'prompt', 'tokens', 'tokens_mask', 'text_encoder_hidden_state',
            'concept'
        ]

        return self._cache_modules_from_names(
            model, model_setup,
            image_split_names=image_split_names,
            image_aggregate_names=image_aggregate_names,
            text_split_names=text_split_names,
            sort_names=sort_names,
            config=config,
            text_caching=True,
        )

    def _output_modules(self, config: TrainConfig, model: Flux2Model, model_setup: BaseFlux2Setup):
        """Collect the names the trainer consumes and build the output modules."""
        output_names = [
            'image_path', 'latent_image',
            'prompt',
            'tokens',
            'tokens_mask',
            'original_resolution', 'crop_resolution', 'crop_offset',
        ]
        if config.masked_training or config.model_type.has_mask_input():
            output_names.append('latent_mask')
        output_names.append('text_encoder_hidden_state')

        return self._output_modules_from_out_names(
            model, model_setup,
            output_names=output_names,
            config=config,
            use_conditioning_image=False,
            vae=model.vae,
            autocast_context=[model.autocast_context],
            train_dtype=model.train_dtype,
        )

    def _debug_modules(self, config: TrainConfig, model: Flux2Model):
        """Build modules that decode cached latents/tokens back to images and text
        and save them under ``<debug_dir>/dataloader`` for visual inspection."""
        debug_dir = os.path.join(config.debug_dir, "dataloader")

        def before_save_fun():
            # ensure the VAE is on the train device before decoding
            model.vae_to(self.train_device)

        decode_image = DecodeVAE(in_name='latent_image', out_name='decoded_image', vae=model.vae, autocast_contexts=[model.autocast_context], dtype=model.train_dtype.torch_dtype())
        # factor=8 inverts the 0.125x downscale applied during preparation
        upscale_mask = ScaleImage(in_name='latent_mask', out_name='decoded_mask', factor=8)
        decode_prompt = DecodeTokens(in_name='tokens', out_name='decoded_prompt', tokenizer=model.tokenizer)
        save_image = SaveImage(image_in_name='decoded_image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1, before_save_fun=before_save_fun)
        save_mask = SaveImage(image_in_name='decoded_mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1, before_save_fun=before_save_fun)
        save_prompt = SaveText(text_in_name='decoded_prompt', original_path_in_name='image_path', path=debug_dir, before_save_fun=before_save_fun)

        # NOTE: saving the raw 'image'/'mask' names here does not work — they are
        # inserted after a sorting operation that does not include this data.

        modules = [decode_image, save_image]
        if config.masked_training or config.model_type.has_mask_input():
            modules += [upscale_mask, save_mask]
        modules += [decode_prompt, save_prompt]
        return modules

    def _create_dataset(
            self,
            config: TrainConfig,
            model: Flux2Model,
            model_setup: BaseFlux2Setup,
            train_progress: TrainProgress,
            is_validation: bool = False,
    ):
        """Delegate to the text2image mixin with Flux2's 64-pixel aspect bucketing."""
        return DataLoaderText2ImageMixin._create_dataset(self,
            config, model, model_setup, train_progress, is_validation,
            aspect_bucketing_quantization=64,
        )


# Register this loader with the component factory so it is selected for FLUX_2 models.
factory.register(BaseDataLoader, Flux2BaseDataLoader, ModelType.FLUX_2)
5 changes: 5 additions & 0 deletions modules/model/ChromaModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,8 @@ def unpack_latents(self, latents, height: int, width: int):
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)

return latents


# NOTE(review): imported at the bottom of the file, presumably to avoid a
# circular import between modules.util.factory and this model — confirm.
from modules.util import factory

# Register ChromaModel with the component factory for the CHROMA_1 model type.
factory.register(BaseModel, ChromaModel, ModelType.CHROMA_1)
Loading