from torchvision import transforms
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
import transformers
-from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
+from diffusers.optimization import SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION
+from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import (
    StableDiffusionPipeline,
    DDPMScheduler,
@@ -2972,6 +2973,20 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):


def add_optimizer_arguments(parser: argparse.ArgumentParser):
+    def int_or_float(value):
+        if value.endswith('%'):
+            try:
+                return float(value[:-1]) / 100.0
+            except ValueError:
+                raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
+        try:
+            float_value = float(value)
+            if float_value >= 1:
+                return int(value)
+            return float(value)
+        except ValueError:
+            raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
+
    parser.add_argument(
        "--optimizer_type",
        type=str,
@@ -3024,9 +3039,15 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
    )
    parser.add_argument(
        "--lr_warmup_steps",
-        type=int,
+        type=int_or_float,
+        default=0,
+        help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
+    )
+    parser.add_argument(
+        "--lr_decay_steps",
+        type=int_or_float,
        default=0,
-        help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
+        help="Int number of steps for the decay in the lr scheduler (default is 0) or float with ratio of train steps",
    )
    parser.add_argument(
        "--lr_scheduler_num_cycles",
@@ -3046,6 +3067,18 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
        help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL"
        + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効",
    )
+    parser.add_argument(
+        "--lr_scheduler_timescale",
+        type=int,
+        default=None,
+        help="Inverse sqrt timescale for inverse sqrt scheduler, defaults to `num_warmup_steps`",
+    )
+    parser.add_argument(
+        "--lr_scheduler_min_lr_ratio",
+        type=float,
+        default=None,
+        help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler",
+    )


def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -4293,10 +4326,14 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
    Unified API to get any scheduler from its name.
    """
    name = args.lr_scheduler
-    num_warmup_steps: Optional[int] = args.lr_warmup_steps
    num_training_steps = args.max_train_steps * num_processes  # * args.gradient_accumulation_steps
+    num_warmup_steps: Optional[int] = int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
+    num_decay_steps: Optional[int] = int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
+    num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
    num_cycles = args.lr_scheduler_num_cycles
    power = args.lr_scheduler_power
+    timescale = args.lr_scheduler_timescale
+    min_lr_ratio = args.lr_scheduler_min_lr_ratio

    lr_scheduler_kwargs = {}  # get custom lr_scheduler kwargs
    if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
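
Because int_or_float returns a float below 1 for ratio-style input, get_scheduler_fix now converts those ratios into concrete step counts and derives the stable phase length for warmup-stable-decay. A minimal sketch of that arithmetic with made-up numbers, mirroring the lines added above (not part of this commit):

max_train_steps, num_processes = 1000, 1
lr_warmup_steps = 0.05   # e.g. from "--lr_warmup_steps 5%"
lr_decay_steps = 200     # e.g. from "--lr_decay_steps 200"

num_training_steps = max_train_steps * num_processes
num_warmup_steps = int(lr_warmup_steps * num_training_steps) if isinstance(lr_warmup_steps, float) else lr_warmup_steps
num_decay_steps = int(lr_decay_steps * num_training_steps) if isinstance(lr_decay_steps, float) else lr_decay_steps
num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps

print(num_warmup_steps, num_stable_steps, num_decay_steps)  # 50 750 200
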
@@ -4332,13 +4369,13 @@ def wrap_check_needless_num_warmup_steps(return_vals):
        # logger.info(f"adafactor scheduler init lr {initial_lr}")
        return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))

-    name = SchedulerType(name)
-    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+    name = SchedulerType(name) or DiffusersSchedulerType(name)
+    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] or DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]

    if name == SchedulerType.CONSTANT:
        return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))

-    if name == SchedulerType.PIECEWISE_CONSTANT:
+    if name == DiffusersSchedulerType.PIECEWISE_CONSTANT:
        return schedule_func(optimizer, **lr_scheduler_kwargs)  # step_rules and last_epoch are given as kwargs

    # All other schedulers require `num_warmup_steps`
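
For reference, the transformers-side lookup used above turns a scheduler name string into its factory function via the SchedulerType enum and the TYPE_TO_SCHEDULER_FUNCTION mapping. A small sketch, assuming a transformers version that ships the cosine_with_min_lr scheduler (illustrative, not part of this commit):

from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION

# SchedulerType("...") raises ValueError for names the installed transformers does not know.
name = SchedulerType("cosine_with_min_lr")
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
print(name, schedule_func.__name__)
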
@@ -4348,6 +4385,9 @@ def wrap_check_needless_num_warmup_steps(return_vals):
    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)

+    if name == SchedulerType.INVERSE_SQRT:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
+
    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
@@ -4366,7 +4406,31 @@ def wrap_check_needless_num_warmup_steps(return_vals):
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs
        )

-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs)
+    if name == SchedulerType.COSINE_WITH_MIN_LR:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=num_cycles / 2,
+            min_lr_rate=min_lr_ratio,
+            **lr_scheduler_kwargs,
+        )
+
+    # All other schedulers require `num_decay_steps`
+    if num_decay_steps is None:
+        raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
+    if name == SchedulerType.WARMUP_STABLE_DECAY:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_stable_steps=num_stable_steps,
+            num_decay_steps=num_decay_steps,
+            num_cycles=num_cycles / 2,
+            min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
+            **lr_scheduler_kwargs,
+        )
+
+    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_decay_steps=num_decay_steps, **lr_scheduler_kwargs)


def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
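
Taken together, the WARMUP_STABLE_DECAY branch above ends up delegating to the warmup-stable-decay schedule in transformers. A minimal sketch of constructing that schedule directly, assuming a transformers version that provides get_wsd_schedule with these keyword arguments (illustrative, not part of this commit):

import torch
from transformers.optimization import get_wsd_schedule

# Dummy optimizer; the step counts match the worked example above (1000 total steps).
optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=1e-3)
scheduler = get_wsd_schedule(
    optimizer,
    num_warmup_steps=50,
    num_stable_steps=750,
    num_decay_steps=200,
    min_lr_ratio=0.0,
)

for _ in range(10):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # still in the warmup phase after 10 steps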