
Commit fd68703

Add New lr scheduler (#1393)
* add new lr scheduler
* fix bugs and use num_cycles / 2
* Update requirements.txt
* add num_cycles for min lr
* keep PIECEWISE_CONSTANT
* allow use float with warmup or decay ratio.
* Update train_util.py
1 parent 62ec3e6 commit fd68703

File tree

2 files changed: +75 −11 lines

library/train_util.py
requirements.txt


library/train_util.py

Lines changed: 72 additions & 8 deletions
@@ -42,7 +42,8 @@
 from torchvision import transforms
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
 import transformers
-from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
+from diffusers.optimization import SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION
+from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
 from diffusers import (
     StableDiffusionPipeline,
     DDPMScheduler,
@@ -2972,6 +2973,20 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
 
 
 def add_optimizer_arguments(parser: argparse.ArgumentParser):
+    def int_or_float(value):
+        if value.endswith('%'):
+            try:
+                return float(value[:-1]) / 100.0
+            except ValueError:
+                raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
+        try:
+            float_value = float(value)
+            if float_value >= 1:
+                return int(value)
+            return float(value)
+        except ValueError:
+            raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
+
     parser.add_argument(
         "--optimizer_type",
         type=str,
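
The `int_or_float` helper lets warmup/decay arguments be given either as an absolute step count or as a fraction of the total training steps. A minimal illustrative sketch (not part of the commit) of how it resolves typical inputs, assuming the helper can be called directly:

    int_or_float("500")   # -> 500  (value >= 1: absolute number of steps, returned as int)
    int_or_float("0.05")  # -> 0.05 (value < 1: ratio of the total training steps, returned as float)
    int_or_float("5%")    # -> 0.05 (percentage form, converted to a ratio)
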
@@ -3024,9 +3039,15 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
     )
     parser.add_argument(
         "--lr_warmup_steps",
-        type=int,
+        type=int_or_float,
+        default=0,
+        help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
+    )
+    parser.add_argument(
+        "--lr_decay_steps",
+        type=int_or_float,
         default=0,
-        help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
+        help="Int number of steps for the decay in the lr scheduler (default is 0) or float with ratio of train steps",
     )
     parser.add_argument(
         "--lr_scheduler_num_cycles",
@@ -3046,6 +3067,18 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL"
         + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効",
     )
+    parser.add_argument(
+        "--lr_scheduler_timescale",
+        type=int,
+        default=None,
+        help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`",
+    )
+    parser.add_argument(
+        "--lr_scheduler_min_lr_ratio",
+        type=float,
+        default=None,
+        help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler",
+    )
 
 
 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -4293,10 +4326,14 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     Unified API to get any scheduler from its name.
     """
     name = args.lr_scheduler
-    num_warmup_steps: Optional[int] = args.lr_warmup_steps
     num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
+    num_warmup_steps: Optional[int] = int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
+    num_decay_steps: Optional[int] = int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
+    num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
     num_cycles = args.lr_scheduler_num_cycles
     power = args.lr_scheduler_power
+    timescale = args.lr_scheduler_timescale
+    min_lr_ratio = args.lr_scheduler_min_lr_ratio
 
     lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
     if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
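
Inside `get_scheduler_fix`, float values are resolved against the total training steps before being passed to the scheduler. A worked example under assumed values (max_train_steps=10000, num_processes=1, --lr_warmup_steps 0.05, --lr_decay_steps 0.2; the numbers are illustrative, not from the commit):

    num_training_steps = 10000 * 1          # 10000
    num_warmup_steps = int(0.05 * 10000)    # 500
    num_decay_steps = int(0.2 * 10000)      # 2000
    num_stable_steps = 10000 - 500 - 2000   # 7500 steps at constant lr for warmup_stable_decay
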
@@ -4332,13 +4369,13 @@ def wrap_check_needless_num_warmup_steps(return_vals):
         # logger.info(f"adafactor scheduler init lr {initial_lr}")
         return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
 
-    name = SchedulerType(name)
-    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+    name = SchedulerType(name) or DiffusersSchedulerType(name)
+    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] or DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
 
     if name == SchedulerType.CONSTANT:
         return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
 
-    if name == SchedulerType.PIECEWISE_CONSTANT:
+    if name == DiffusersSchedulerType.PIECEWISE_CONSTANT:
         return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs
 
     # All other schedulers require `num_warmup_steps`
@@ -4348,6 +4385,9 @@ def wrap_check_needless_num_warmup_steps(return_vals):
     if name == SchedulerType.CONSTANT_WITH_WARMUP:
         return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
 
+    if name == SchedulerType.INVERSE_SQRT:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
+
     # All other schedulers require `num_training_steps`
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
@@ -4366,7 +4406,31 @@ def wrap_check_needless_num_warmup_steps(return_vals):
             optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs
         )
 
-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs)
+    if name == SchedulerType.COSINE_WITH_MIN_LR:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=num_cycles / 2,
+            min_lr_rate=min_lr_ratio,
+            **lr_scheduler_kwargs,
+        )
+
+    # All other schedulers require `num_decay_steps`
+    if num_decay_steps is None:
+        raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
+    if name == SchedulerType.WARMUP_STABLE_DECAY:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_stable_steps=num_stable_steps,
+            num_decay_steps=num_decay_steps,
+            num_cycles=num_cycles / 2,
+            min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
+            **lr_scheduler_kwargs,
+        )
+
+    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_decay_steps=num_decay_steps, **lr_scheduler_kwargs)
 
 
 def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
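
Taken together, the new arguments allow scheduler setups along these lines (the scheduler name strings are assumed to follow the transformers `SchedulerType` values; all numeric values are illustrative, not from the commit):

    --lr_scheduler cosine_with_min_lr --lr_warmup_steps 0.05 --lr_scheduler_min_lr_ratio 0.1
    --lr_scheduler warmup_stable_decay --lr_warmup_steps 0.1 --lr_decay_steps 0.2 --lr_scheduler_min_lr_ratio 0.0
    --lr_scheduler inverse_sqrt --lr_warmup_steps 500 --lr_scheduler_timescale 1000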

requirements.txt

Lines changed: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
-accelerate==0.25.0
-transformers==4.36.2
+accelerate==0.30.0
+transformers==4.41.2
 diffusers[torch]==0.25.0
 ftfy==6.1.1
 # albumentations==1.3.0
@@ -16,7 +16,7 @@ altair==4.2.2
 easygui==0.98.3
 toml==0.10.2
 voluptuous==0.13.1
-huggingface-hub==0.20.1
+huggingface-hub==0.23.3
 # for Image utils
 imagesize==1.4.1
 # for BLIP captioning
