Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 36 additions & 21 deletions library/lumina_train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,47 +808,48 @@ def get_noisy_model_input_and_timesteps(
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Get noisy model input and timesteps.

Args:
args (argparse.Namespace): Arguments.
noise_scheduler (noise_scheduler): Noise scheduler.
latents (Tensor): Latents.
noise (Tensor): Latent noise.
device (torch.device): Device.
dtype (torch.dtype): Data type

Return:
Tuple[Tensor, Tensor, Tensor]:
noisy model input
timesteps
timesteps (reversed for Lumina: t=0 noise, t=1 image)
sigmas
"""
bsz, _, h, w = latents.shape
sigmas = None

if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
t = torch.rand((bsz,), device=device)

timesteps = t * 1000.0

# Reverse for Lumina: t=0 is noise, t=1 is image
t_lumina = 1.0 - t
timesteps = t_lumina * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents

elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = (
logits_norm * args.sigmoid_scale
) # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)

t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
logits_norm = logits_norm * args.sigmoid_scale
t = logits_norm.sigmoid()
t = (t * shift) / (1 + (shift - 1) * t)

# Reverse for Lumina: t=0 is noise, t=1 is image
t_lumina = 1.0 - t
timesteps = t_lumina * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents

elif args.timestep_sampling == "nextdit_shift":
t = torch.rand((bsz,), device=device)
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
Expand All @@ -857,6 +858,15 @@ def get_noisy_model_input_and_timesteps(
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents

elif args.timestep_sampling == "lognorm":
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you add a new timestep sampling method, it seems that you also need to add it to --timestep_sampling for add_lumina_train_arguments in lumina_train_util.py.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok i will add it into add_lumina_train_arguments in lumina_train_util in the next pull request

u = torch.normal(mean=0.0, std=1.0, size=(bsz,), device=device)
t = torch.sigmoid(u) # maps to [0,1]

timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents

else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
Expand All @@ -868,14 +878,19 @@ def get_noisy_model_input_and_timesteps(
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)

# Add noise according to flow matching.
sigmas = get_sigmas(
noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
timesteps_normal = noise_scheduler.timesteps[indices].to(device=device)

# Reverse for Lumina convention
timesteps = noise_scheduler.config.num_train_timesteps - timesteps_normal

# Calculate sigmas with normal timesteps, then reverse interpolation
sigmas_normal = get_sigmas(
noise_scheduler, timesteps_normal, device, n_dim=latents.ndim, dtype=dtype
)
# Reverse sigma interpolation for Lumina
sigmas = 1.0 - sigmas_normal
noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise

return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas


Expand Down