80 changes: 72 additions & 8 deletions library/train_util.py
@@ -42,7 +42,8 @@
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
import transformers
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers.optimization import SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION
from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import (
StableDiffusionPipeline,
DDPMScheduler,
@@ -2971,6 +2972,20 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):


def add_optimizer_arguments(parser: argparse.ArgumentParser):
def int_or_float(value):
if value.endswith('%'):
try:
return float(value[:-1]) / 100.0
except ValueError:
raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
try:
float_value = float(value)
if float_value >= 1:
return int(value)
return float(value)
except ValueError:
raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")

parser.add_argument(
"--optimizer_type",
type=str,
@@ -3023,9 +3038,15 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
)
parser.add_argument(
"--lr_warmup_steps",
type=int,
type=int_or_float,
default=0,
help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
)
parser.add_argument(
"--lr_decay_steps",
type=int_or_float,
default=0,
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
help="Int number of steps for the decay in the lr scheduler (default is 0) or float with ratio of train steps",
)
parser.add_argument(
"--lr_scheduler_num_cycles",
@@ -3045,6 +3066,18 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL"
+ " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効",
)
parser.add_argument(
"--lr_scheduler_timescale",
type=int,
default=None,
help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`",
)
parser.add_argument(
"--lr_scheduler_min_lr_ratio",
type=float,
default=None,
help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler",
)
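For orientation, --lr_scheduler_min_lr_ratio is expressed relative to the optimizer's initial learning rate; a hypothetical setting (numbers invented for illustration):

# base lr = 1e-4 and --lr_scheduler_min_lr_ratio 0.1 (hypothetical values):
#   cosine_with_min_lr  : lr anneals from 1e-4 down to 0.1 * 1e-4 = 1e-5
#   warmup_stable_decay : the decay phase likewise ends near 1e-5 (the ratio falls back to 0.0 when unset)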


def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -4292,10 +4325,14 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
Unified API to get any scheduler from its name.
"""
name = args.lr_scheduler
num_warmup_steps: Optional[int] = args.lr_warmup_steps
num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
num_warmup_steps: Optional[int] = int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
num_decay_steps: Optional[int] = int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
num_cycles = args.lr_scheduler_num_cycles
power = args.lr_scheduler_power
timescale = args.lr_scheduler_timescale
min_lr_ratio = args.lr_scheduler_min_lr_ratio
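A quick numeric walk-through of the ratio handling above, using invented values:

# Hypothetical values, for illustration only:
# args.max_train_steps = 1000, num_processes = 2   -> num_training_steps = 2000
# args.lr_warmup_steps = 0.1 (i.e. "10%")          -> num_warmup_steps   = int(0.1 * 2000)  = 200
# args.lr_decay_steps  = 0.25                      -> num_decay_steps    = int(0.25 * 2000) = 500
# num_stable_steps     = 2000 - 200 - 500          = 1300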

lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
@@ -4331,13 +4368,13 @@ def wrap_check_needless_num_warmup_steps(return_vals):
# logger.info(f"adafactor scheduler init lr {initial_lr}")
return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))

name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
name = SchedulerType(name) or DiffusersSchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] or DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
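For reference, SchedulerType(name) raises ValueError (and the dict lookup raises KeyError) for an unknown name rather than returning a falsy value, so a try/except fallback is one way to sketch resolution across the two enums; this is illustrative only and not part of the diff:

def resolve_scheduler_type(name):
    # Hypothetical helper: prefer the transformers enum, fall back to the
    # diffusers enum (which provides piecewise_constant) when the name is unknown.
    try:
        scheduler_type = SchedulerType(name)
        return scheduler_type, TYPE_TO_SCHEDULER_FUNCTION[scheduler_type]
    except ValueError:
        scheduler_type = DiffusersSchedulerType(name)
        return scheduler_type, DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[scheduler_type]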

if name == SchedulerType.CONSTANT:
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))

if name == SchedulerType.PIECEWISE_CONSTANT:
if name == DiffusersSchedulerType.PIECEWISE_CONSTANT:
return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs

# All other schedulers require `num_warmup_steps`
@@ -4347,6 +4384,9 @@ def wrap_check_needless_num_warmup_steps(return_vals):
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)

if name == SchedulerType.INVERSE_SQRT:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
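For orientation, the timescale flag shapes the post-warmup decay of the inverse-sqrt schedule; a rough, hedged reconstruction of the curve (simplified sketch, not the library source):

import math

def approx_inverse_sqrt_lr(step, base_lr, num_warmup_steps, timescale=None):
    # Rough shape only: linear warmup, then a decay proportional to 1/sqrt(step),
    # with `timescale` (defaulting to the warmup length) setting how fast it falls.
    timescale = timescale or num_warmup_steps
    if step < num_warmup_steps:
        return base_lr * step / max(1, num_warmup_steps)
    shift = timescale - num_warmup_steps
    return base_lr / math.sqrt((step + shift) / timescale)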

# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
@@ -4365,7 +4405,31 @@ def wrap_check_needless_num_warmup_steps(return_vals):
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs
)

return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs)
if name == SchedulerType.COSINE_WITH_MIN_LR:
return schedule_func(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_cycles=num_cycles / 2,
min_lr_rate=min_lr_ratio,
**lr_scheduler_kwargs,
)

# All other schedulers require `num_decay_steps`
if num_decay_steps is None:
raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
if name == SchedulerType.WARMUP_STABLE_DECAY:
return schedule_func(
optimizer,
num_warmup_steps=num_warmup_steps,
num_stable_steps=num_stable_steps,
num_decay_steps=num_decay_steps,
num_cycles=num_cycles / 2,
min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
**lr_scheduler_kwargs,
)

return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_decay_steps=num_decay_steps, **lr_scheduler_kwargs)
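To make the new warmup_stable_decay branch concrete, here is a hypothetical invocation and the resulting schedule shape (flag values invented; phase lengths follow the ratio arithmetic above):

#   --lr_scheduler warmup_stable_decay --lr_warmup_steps 5% --lr_decay_steps 20% --lr_scheduler_min_lr_ratio 0.1
# With num_training_steps = 10_000 this works out to roughly:
#   warmup : steps 0-500         lr rises linearly to the base lr
#   stable : steps 500-8_000     lr held at the base lr
#   decay  : steps 8_000-10_000  lr falls toward 0.1 * base lr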


def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,5 +1,5 @@
accelerate==0.25.0
transformers==4.36.2
accelerate==0.30.0
transformers==4.41.2
diffusers[torch]==0.25.0
ftfy==6.1.1
# albumentations==1.3.0
@@ -16,7 +16,7 @@ altair==4.2.2
easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.20.1
huggingface-hub==0.23.3
# for Image utils
imagesize==1.4.1
# for BLIP captioning