diff --git a/library/train_util.py b/library/train_util.py index 1c9f07bfa..3df01e3e7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -19,6 +19,7 @@ Sequence, Tuple, Union, + Callable, ) from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState import glob @@ -5239,7 +5240,6 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, timesteps = time_shift(mu, 1.0, timesteps) else: timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) - t = timesteps.view(-1, 1, 1, 1) timesteps = min_timestep + (timesteps * (max_timestep - min_timestep)) else: timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")