Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Timestep Sampling Function from SD3 Branch to SD (dev base) #1671

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 45 additions & 6 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Sequence,
Tuple,
Union,
Callable,
)
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
import glob
Expand Down Expand Up @@ -3476,6 +3477,25 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=0.1,
help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
)
parser.add_argument(
"--timestep_sampling",
choices=["uniform", "sigmoid", "shift", "flux_shift"],
default="uniform",
help="Method to sample timesteps: uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
" / タイムステップをサンプリングする方法:random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
)
parser.add_argument(
"--sigmoid_scale",
type=float,
default=1.0,
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
)
parser.add_argument(
"--discrete_flow_shift",
type=float,
default=1.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 1.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは1.0。",
)

parser.add_argument(
"--lowram",
Expand Down Expand Up @@ -5198,9 +5218,31 @@ def save_sd_model_on_train_end_common(
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)

def time_shift(mu: float, sigma: float, t: torch.Tensor):
    """Apply the FLUX-style exponential time shift to sampled timesteps.

    Maps each t in (0, 1] to exp(mu) / (exp(mu) + (1/t - 1) ** sigma),
    warping the timestep distribution toward the shift controlled by ``mu``.

    Args:
        mu: Log-space shift parameter (typically from ``get_lin_function``).
        sigma: Exponent applied to the odds ratio; 1.0 leaves it linear.
        t: Tensor of timesteps in (0, 1]; transformed elementwise.

    Returns:
        Tensor of shifted timesteps, same shape as ``t``.
    """
    # Compute exp(mu) once; it appears in both numerator and denominator.
    e_mu = math.exp(mu)
    return e_mu / (e_mu + (1.0 / t - 1.0) ** sigma)

def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    """Build the affine function passing through (x1, y1) and (x2, y2).

    Used to interpolate the FLUX shift parameter ``mu`` from the latent
    resolution (number of 2x2 patches). Defaults match the FLUX.1 reference
    values: 256 tokens -> 0.5, 4096 tokens -> 1.15.

    Args:
        x1: First reference x (default 256).
        y1: First reference y (default 0.5).
        x2: Second reference x (default 4096).
        y2: Second reference y (default 1.15).

    Returns:
        A callable mapping x to ``slope * x + intercept``.
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def _affine(x: float) -> float:
        return slope * x + intercept

    return _affine

def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, latents, device):
# Sample a random timestep for each image
b_size, _, h, w = latents.shape

if args.timestep_sampling != "uniform":
shift = args.discrete_flow_shift
logits_norm = torch.randn(b_size, device="cpu")
logits_norm = logits_norm * args.sigmoid_scale
timesteps = logits_norm.sigmoid()
if args.timestep_sampling == "flux_shift":
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
timesteps = time_shift(mu, 1.0, timesteps)
else:
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
timesteps = min_timestep + (timesteps * (max_timestep - min_timestep))
else:
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")

if args.loss_type == "huber" or args.loss_type == "smooth_l1":
if args.huber_schedule == "exponential":
Expand All @@ -5223,7 +5265,6 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler,
timesteps = timesteps.long().to(device)
return timesteps, huber_c


def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
Expand All @@ -5238,12 +5279,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount
)

# Sample a random timestep for each image
b_size = latents.shape[0]
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep

timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)
timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, latents, latents.device)

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
Expand Down