
Conversation

duongve13112002

Description

This PR fixes the issue where Lumina's reversed timesteps (using t=0 as noise and t=1 as image) were not properly handled in some functions. As a result, certain timestep sampling methods (other than nextdit_shift) did not work as expected, causing the model to fail to learn even after thousands of steps.

The fix ensures that timestep handling is consistent with Lumina’s reversed convention.

In addition, this PR introduces a new timestep type named lognorm.

Changes

  • Fixed reversed timestep handling in lumina_train_util.py and related functions.
  • Adjusted affected methods so that they properly account for the t=0 noise / t=1 image convention (see the sketch after this list).
  • Added support for a new timestep type: lognorm.
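
For context, here is a minimal sketch of the reversed convention these fixes target; the variable names are illustrative and not the exact ones used in lumina_train_util.py.

```python
import torch

# Minimal sketch of Lumina's reversed flow-matching convention
# (t=0 -> pure noise, t=1 -> clean image). Names are illustrative only.
latents = torch.randn(2, 4, 32, 32)   # clean image latents
noise = torch.randn_like(latents)     # Gaussian noise
t = torch.rand(latents.shape[0])      # per-sample timesteps in [0, 1]

t = t.view(-1, 1, 1, 1)
# t=1 reproduces the image, t=0 reproduces the noise
noisy_model_input = (1 - t) * noise + t * latents
```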

Commits

… for lumina image v2 and add new timestep

Resolve the issue reported at kohya-ss#2201 and introduce a new timestep type called "lognorm".

…ed_timesteps

Fix Lumina reversed timestep handling (kohya-ss#2201) and add "lognorm" sampling
@kohya-ss (Owner) left a comment

Thank you. It looks good.
However, it seems that Diffusers calculates 1-timestep just before calling DiT. One idea would be to unify the timestep calculation to that method. What do you think?

https://github.com/huggingface/diffusers/blob/0a151115bbe493de74a4565e57352a0890e94777/src/diffusers/pipelines/lumina/pipeline_lumina.py#L846

t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents
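
For illustration, here is a sketch of one way to read this suggestion, assuming timesteps are kept in the usual [0, 1000] convention everywhere else and only flipped immediately before the model call; the helper name call_dit is hypothetical.

```python
# Hypothetical helper: keep timesteps in the standard convention throughout
# training, and flip to Lumina's reversed convention only at the DiT boundary,
# mirroring the Diffusers Lumina pipeline linked above.
def call_dit(dit, noisy_model_input, timesteps, **cond):
    reversed_t = 1.0 - timesteps / 1000.0  # assumes timesteps in [0, 1000]
    return dit(noisy_model_input, reversed_t, **cond)
```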

elif args.timestep_sampling == "lognorm":
kohya-ss (Owner)

If you add a new timestep sampling method, it seems that you also need to add it to the --timestep_sampling choices in add_lumina_train_arguments in lumina_train_util.py.
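
For reference, a minimal sketch of that registration; the existing choices and the default shown here are assumptions, not copied from lumina_train_util.py.

```python
import argparse

# Sketch only: in the real code the parser is passed into
# add_lumina_train_arguments; the choices/default here are assumptions.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--timestep_sampling",
    type=str,
    default="nextdit_shift",
    choices=["uniform", "sigmoid", "shift", "nextdit_shift", "lognorm"],
    help="Timestep sampling method; 'lognorm' is the new option",
)
```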

duongve13112002 (Author)

OK, I will add it to add_lumina_train_arguments in lumina_train_util in the next pull request.

@urlesistiana commented Sep 29, 2025

Diffusers calculates 1-timestep just before calling DiT.

+1. I was also trying to reverse all the timestep samplings, but found that other functions outside the Lumina files rely on these timesteps, e.g. min_snr_gamma. Reversing it in one place makes things much easier; we won't have to change the other functions.

lazy man forgot to open a pr, orz

@duongve13112002 (Author) commented Sep 29, 2025

I have made the adjustments as you requested. Also, the new timestep type I implemented is similar to sigmoid, so I deleted it and did not add it to this pull request. In addition, I have fixed the issue related to fine-tuning the Lumina model with multiple GPUs. Moreover, in the current Lumina fine-tuning code, when args.blockwise_fused_optimizers is enabled, the model's parameters are not updated. At the moment I don't know how to fix this, so I have disabled this feature to prevent errors for users. Sorry for committing multiple times; I'm making changes on my tablet.

@kohya-ss (Owner) left a comment

Thank you for the update! I think it would be great to set the timestep to 1-t.

It seems that training is possible with the current code (without PR) when using next_dit, but is it correct to understand that no changes to next_dit are necessary?

Edit: time_shift seems to need updating.

Comment on lines +5528 to +5530
(
DistributedDataParallelKwargs(find_unused_parameters=True)
),
kohya-ss (Owner)

What was the purpose of this addition? I would appreciate an explanation.

duongve13112002 (Author)

According to my testing, when fully fine-tuning the Lumina image model on multiple GPUs you get the error "expected gradient for parameter … but none found". Adding this handles the problem, so training runs normally on multi-GPU without errors.
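
For context, this is roughly how such a kwargs handler reaches DDP through Accelerate; a minimal sketch, not the exact call site in the shared training code.

```python
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# Minimal sketch: find_unused_parameters=True makes DDP tolerate parameters
# that receive no gradient in a given step (one common cause of missing-
# gradient errors in multi-GPU fine-tuning), at the cost of an extra graph
# traversal on every step.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
```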

kohya-ss (Owner)

Thank you for the explanation. This function is commonly called by all model training scripts, so any changes made here will require testing all models.

I think it might be a good idea to find out why Lumina needs this argument and solve that problem.

duongve13112002 (Author)

Maybe we can add a flag to enable this when fine-tuning Lumina models; that could improve flexibility.
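
A minimal sketch of such an opt-in flag; the flag name and wiring are hypothetical and not part of this PR.

```python
import argparse

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# Hypothetical flag: only request find_unused_parameters when explicitly asked
# for (e.g. by the Lumina fine-tuning scripts), leaving other models untouched.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--ddp_find_unused_parameters",
    action="store_true",
    help="Pass find_unused_parameters=True to DDP (needed for some Lumina fine-tuning setups)",
)
args = parser.parse_args()

kwargs_handlers = []
if args.ddp_find_unused_parameters:
    kwargs_handlers.append(DistributedDataParallelKwargs(find_unused_parameters=True))
accelerator = Accelerator(kwargs_handlers=kwargs_handlers)
```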

t = time_shift(mu, 1.0, t)

timesteps = t * 1000.0
timesteps = 1 - t * 1000.0
duongve13112002 (Author)

I also reversed the sampling of the ‘nextdit_shift’ timestep to synchronize it with the current code.


duongve13112002 (Author)

I think we don’t need to change anything here because another function calls this one. Changing the code in this function could potentially break the training pipeline, for example, in this function
https://github.com/duongve13112002/sd-scripts/blob/4d24b71c1647f674951f482857c12c74a5a46440/library/lumina_train_util.py#L507-L537

kohya-ss (Owner)

I think that get_schedule will not return the correct values unless time_shift is modified.

In this PR, the model input has been inverted to 1-t, so if you leave time_shift unmodified, the shift value will be inverted. In other words, the implementation of time_shift should be the same as in FLUX.1.
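
For reference, the FLUX.1-style time_shift being referred to looks roughly like this; it is based on the publicly available FLUX sampling code, so verify the exact signature against the existing sd-scripts FLUX utilities before copying.

```python
import math
import torch

# Sketch of the FLUX.1-style time_shift (a monotone shift of t driven by mu);
# confirm against the existing FLUX implementation before reuse.
def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
```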

@duongve13112002 (Author) commented Oct 1, 2025

OK, I will change it right now.
