Conversation

@minux302 commented Nov 30, 2024

I implemented FLUX.1 ControlNet, based on the x-flux implementation. I trained a Canny ControlNet on top of flux1-dev; the images below are test results from training on my own dataset. The implementation appears to be working.

Test Results

[Image: condition / result pairs (screenshot, 2024-12-01)]

Dataset

  • 50,000 photos, resized to 1024x1024 (resized without preserving the original aspect ratio)
  • Captions generated with Florence

Training Settings

I trained the ControlNet from scratch: bs=1, accumulation=8, steps=10,000, taking roughly 30-40 hours on an H100.

accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_control_net.py \
--pretrained_model_name_or_path flux1-dev.safetensors \
--clip_l clip_l.safetensors \
--t5xxl t5xxl_fp8_e4m3fn.safetensors \
--ae ae.safetensors \
--save_model_as safetensors --sdpa --persistent_data_loader_workers \
--max_data_loader_n_workers 1 --seed 42 --gradient_checkpointing --mixed_precision bf16 \
--optimizer_type adamw8bit --learning_rate 2e-5 \
--highvram --max_train_epochs 1 --save_every_n_steps 1000 --dataset_config /path/to/dataset.toml \
--output_dir exp001 --output_name flux-cn \
--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0  --deepspeed \
--sample_every_n_steps 1000 --sample_prompts /path/to/prompts.toml \
--log_with "wandb" --log_tracker_name "sd-scripts-flux-cn" 
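
For reference, the sample prompts file isn't included in the PR. Below is a minimal sketch of what it might look like, assuming sd-scripts' TOML sample-prompt schema ([prompt] for shared settings plus one [[prompt.subset]] entry per sample); the prompt text and values are placeholders, and how a conditioning image is supplied for sampling depends on the merged script.

# prompts.toml (hypothetical example)

[prompt]
width = 1024
height = 1024
sample_steps = 20

[[prompt.subset]]
prompt = "a photo of a city street at night"
seed = 42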

deepspeed config

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 8
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
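
This accelerate config is the file produced by running accelerate config (or passed explicitly via accelerate launch --config_file ...). Note that distributed_type: DEEPSPEED pairs with the --deepspeed flag on the training script, and gradient_accumulation_steps: 8 matches the accumulation setting mentioned above.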

@sdbds (Contributor) commented Dec 1, 2024

It looks very nice! It could be adapted for SD3.5 with a few more modifications.

@FurkanGozukara commented

Excellent work!

Can you show an example of the dataset?

If you could post an example of /path/to/dataset.toml, that would be amazing.

@kohya-ss (Owner) commented Dec 2, 2024

Thank you for this! We have confirmed that training runs without issues in a Windows WSL environment at 512x512 resolution with 48GB of VRAM.

kohya-ss merged commit 09a3740 into kohya-ss:sd3 on Dec 2, 2024
@kohya-ss (Owner) commented Dec 2, 2024

Thank you again for this great PR! I have a few questions.

Do you think it's ok to add flux.enable_gradient_checkpointing(...) to the following place?

    if args.gradient_checkpointing:
        controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)

Also, you said you trained for 10,000 steps with gradient accumulation steps = 8. Did you specify 10,000 train steps (1,250 optimizer steps in effect) or 80,000 steps?

@minux302 (Author) commented Dec 3, 2024

If you could post an example of /path/to/dataset.toml

Here!

# dataset.toml

[general]
# shuffle_caption = true
resolution = [1024, 1024]
batch_size = 1
enable_bucket = false

[[datasets]]
    [[datasets.subsets]]
        image_dir = "/path/to/images"
        caption_extension = ".txt"
        conditioning_data_dir = "/path/to/conditional_images"
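
A note on layout (my reading of the sd-scripts ControlNet dataset handling, so treat the specifics as an assumption): the conditioning images in conditioning_data_dir are matched to the training images by filename, along these lines:

/path/to/images/0001.png
/path/to/images/0001.txt            # caption
/path/to/conditional_images/0001.png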

Do you think it's ok to add flux.enable_gradient_checkpointing(...) to the following place?

After the model's __init__, I think it may be OK.
ref: the X-Labs implementation:
https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/controlnet.py#L101

Did you specify 10,000 train steps (1,250 optimizer steps in effect) or 80,000 steps?

80,000 steps. After 2,000 x 8 steps, the validation results seem to have converged. The trained blocks are only num_double_blocks=2, so maybe convergence is faster because there are fewer layers to learn.
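
For context on num_double_blocks=2: an x-flux-style FLUX ControlNet replicates only the first few double blocks of the base transformer and feeds their outputs back through zero-initialized projections. A simplified, hypothetical sketch of the idea (class and parameter names are illustrative, not the sd-scripts API):

import torch.nn as nn

class FewBlockControlNet(nn.Module):
    # Illustrative only: a ControlNet that trains just the first
    # num_double_blocks transformer blocks instead of the full stack.
    def __init__(self, make_double_block, num_double_blocks=2, hidden_size=3072):
        super().__init__()
        # replicate only a small prefix of the base model's double blocks
        self.double_blocks = nn.ModuleList(
            make_double_block() for _ in range(num_double_blocks)
        )
        # one zero-initialized projection per block, so the ControlNet
        # starts as a no-op and learns its influence during training
        self.controlnet_blocks = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size) for _ in range(num_double_blocks)
        )
        for proj in self.controlnet_blocks:
            nn.init.zeros_(proj.weight)
            nn.init.zeros_(proj.bias)

With only two double blocks (plus projections) to optimize, the trainable parameter count is a small fraction of the full model, which is consistent with the fast convergence reported above.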

@kohya-ss (Owner) commented Dec 3, 2024

After the model's __init__, I think it may be OK.

Thank you for your suggestion!

80,000 steps. After 2,000 x 8 steps, the validation results seem to have converged. The trained blocks are only num_double_blocks=2, so maybe convergence is faster because there are fewer layers to learn.

Thank you for clarifying this as well. That is a lot of steps, but considering how difficult FLUX.1 is to train, it's surprisingly few.

@FurkanGozukara commented

@minux302 thanks a lot, but can you give a few image examples?

@q654517651 commented

Hello, thank you very much for your work. Could you please provide an example of this file?
--sample_prompts /path/to/prompts.toml \

@seniorsolt commented

Amazing news!

@minux302, could you please clarify: did the sudden convergence, where the model starts following the ControlNet image, only happen at the end of the 80,000 steps?

@Johnson-yue commented

@minux302 how do you use the trained ControlNet weights? Is there any inference code for testing?

@seniorsolt commented

@Johnson-yue try something from xlabs repo https://github.com/XLabs-AI/x-flux

@Johnson-yue commented

@Johnson-yue try something from xlabs repo https://github.com/XLabs-AI/x-flux

OK, Thanks

@yuweifanf commented

@minux302 Hi, have you ever encountered the following problem?

F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
