Commit 6dbfd47

Fix PIECEWISE_CONSTANT scheduler, update requirements.txt and README #1393

1 parent fd68703 commit 6dbfd47

File tree

README.md
library/train_util.py
requirements.txt

3 files changed: +54 −25 lines changed

README.md

Lines changed: 9 additions & 0 deletions
```diff
@@ -139,6 +139,15 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 
 ### Working in progress
 
+- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
+  - transformers, accelerate and huggingface_hub are updated.
+  - If you encounter any issues, please report them.
+
+- The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds!
+  - See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler.
+  - `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc.
+
+https://github.com/kohya-ss/sd-scripts/pull/1393
 - When enlarging images in the script (when the size of the training image is small and bucket_no_upscale is not specified), it has been changed to use Pillow's resize and LANCZOS interpolation instead of OpenCV2's resize and Lanczos4 interpolation. The quality of the image enlargement may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds!
 
 - Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v!
```
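As a quick illustration of the new ratio syntax (the step counts are assumed for the example, not taken from the commit), the three spellings below all select the same warmup length; the actual conversion happens in the `get_scheduler_fix` hunk further down:

```python
# Assumed example: 2000 total training steps.
#   --lr_warmup_steps=200   -> parsed as int 200   (absolute step count)
#   --lr_warmup_steps=0.1   -> parsed as float 0.1 (ratio of training steps)
#   --lr_warmup_steps=10%   -> parsed as float 0.1 (percentage form)
num_training_steps = 2000
for parsed in (200, 0.1):
    warmup = int(parsed * num_training_steps) if isinstance(parsed, float) else parsed
    assert warmup == 200  # all three spellings yield 200 warmup steps
```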

library/train_util.py

Lines changed: 43 additions & 23 deletions
```diff
@@ -42,7 +42,10 @@
 from torchvision import transforms
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
 import transformers
-from diffusers.optimization import SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION
+from diffusers.optimization import (
+    SchedulerType as DiffusersSchedulerType,
+    TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
+)
 from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
 from diffusers import (
     StableDiffusionPipeline,
```
```diff
@@ -2974,7 +2977,7 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
 
 def add_optimizer_arguments(parser: argparse.ArgumentParser):
     def int_or_float(value):
-        if value.endswith('%'):
+        if value.endswith("%"):
             try:
                 return float(value[:-1]) / 100.0
             except ValueError:
```
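The hunk shows only the percentage branch of `int_or_float`. For context, here is a minimal sketch of what such a parser looks like as a whole; the non-percentage branches are assumed from the flags' help text (an int for an absolute step count, a float below 1 for a ratio), not copied from the commit:

```python
import argparse

def int_or_float(value: str):
    # "10%" -> 0.10: normalize a percentage to a ratio of total training steps
    if value.endswith("%"):
        try:
            return float(value[:-1]) / 100.0
        except ValueError:
            raise argparse.ArgumentTypeError(f"'{value}' is not a valid percentage")
    # "200" -> 200 (absolute steps); "0.1" -> 0.1 (ratio of training steps)
    try:
        return int(value)
    except ValueError:
        return float(value)
```

Downstream code distinguishes the two cases with `isinstance(..., float)`, so returning an int versus a float is what encodes "absolute steps" versus "ratio".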
```diff
@@ -3041,13 +3044,15 @@ def int_or_float(value):
         "--lr_warmup_steps",
         type=int_or_float,
         default=0,
-        help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
+        help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps"
+        " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
     )
     parser.add_argument(
         "--lr_decay_steps",
         type=int_or_float,
         default=0,
-        help="Int number of steps for the decay in the lr scheduler (default is 0) or float with ratio of train steps",
+        help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps"
+        " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
     )
     parser.add_argument(
         "--lr_scheduler_num_cycles",
```
```diff
@@ -3071,13 +3076,16 @@ def int_or_float(value):
         "--lr_scheduler_timescale",
         type=int,
         default=None,
-        help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`",
+        help="Inverse sqrt timescale for inverse sqrt scheduler, defaults to `num_warmup_steps`"
+        " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`"
+        ,
     )
     parser.add_argument(
         "--lr_scheduler_min_lr_ratio",
         type=float,
         default=None,
-        help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler",
+        help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler"
+        " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラで有効",
     )
 
 
```
```diff
@@ -4327,8 +4335,12 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     """
     name = args.lr_scheduler
     num_training_steps = args.max_train_steps * num_processes  # * args.gradient_accumulation_steps
-    num_warmup_steps: Optional[int] = int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
-    num_decay_steps: Optional[int] = int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
+    num_warmup_steps: Optional[int] = (
+        int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
+    )
+    num_decay_steps: Optional[int] = (
+        int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
+    )
     num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
     num_cycles = args.lr_scheduler_num_cycles
     power = args.lr_scheduler_power
```
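A worked example of the conversion above, with assumed values. Note that a ratio applies to `max_train_steps * num_processes`, the process-scaled total:

```python
# Assumed values for illustration only.
max_train_steps = 1000
num_processes = 2
num_training_steps = max_train_steps * num_processes  # 2000, as computed above

lr_warmup_steps = 0.05  # float: ratio form, e.g. --lr_warmup_steps=5%
lr_decay_steps = 200    # int: absolute form, e.g. --lr_decay_steps=200

num_warmup_steps = int(lr_warmup_steps * num_training_steps) if isinstance(lr_warmup_steps, float) else lr_warmup_steps
num_decay_steps = int(lr_decay_steps * num_training_steps) if isinstance(lr_decay_steps, float) else lr_decay_steps
num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps

print(num_warmup_steps, num_stable_steps, num_decay_steps)  # 100 1700 200
```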
```diff
@@ -4369,15 +4381,17 @@ def wrap_check_needless_num_warmup_steps(return_vals):
         # logger.info(f"adafactor scheduler init lr {initial_lr}")
         return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
 
-    name = SchedulerType(name) or DiffusersSchedulerType(name)
-    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] or DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
+    if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value:
+        name = DiffusersSchedulerType(name)
+        schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
+        return schedule_func(optimizer, **lr_scheduler_kwargs)  # step_rules and last_epoch are given as kwargs
+
+    name = SchedulerType(name)
+    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
 
     if name == SchedulerType.CONSTANT:
         return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
 
-    if name == DiffusersSchedulerType.PIECEWISE_CONSTANT:
-        return schedule_func(optimizer, **lr_scheduler_kwargs)  # step_rules and last_epoch are given as kwargs
-
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
```
```diff
@@ -4408,11 +4422,11 @@ def wrap_check_needless_num_warmup_steps(return_vals):
 
     if name == SchedulerType.COSINE_WITH_MIN_LR:
         return schedule_func(
-        optimizer,
-        num_warmup_steps=num_warmup_steps,
-        num_training_steps=num_training_steps,
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
             num_cycles=num_cycles / 2,
-        min_lr_rate=min_lr_ratio,
+            min_lr_rate=min_lr_ratio,
             **lr_scheduler_kwargs,
         )
 
```
```diff
@@ -4421,16 +4435,22 @@ def wrap_check_needless_num_warmup_steps(return_vals):
         raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
     if name == SchedulerType.WARMUP_STABLE_DECAY:
         return schedule_func(
-        optimizer,
-        num_warmup_steps=num_warmup_steps,
-        num_stable_steps=num_stable_steps,
-        num_decay_steps=num_decay_steps,
-        num_cycles=num_cycles / 2,
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_stable_steps=num_stable_steps,
+            num_decay_steps=num_decay_steps,
+            num_cycles=num_cycles / 2,
             min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
             **lr_scheduler_kwargs,
         )
 
-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_decay_steps=num_decay_steps, **lr_scheduler_kwargs)
+    return schedule_func(
+        optimizer,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_decay_steps=num_decay_steps,
+        **lr_scheduler_kwargs,
+    )
 
 
 def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
```
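For completeness, a hedged sketch of driving one of the newly supported transformers schedules the same way the patched code does. The optimizer, step counts, and hyperparameters are assumed for the example; `min_lr_rate` matches the keyword the diff passes for COSINE_WITH_MIN_LR:

```python
import torch
from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION

# Assumed toy optimizer for illustration.
optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)

name = SchedulerType("cosine_with_min_lr")
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
scheduler = schedule_func(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
    num_cycles=0.5,   # the patched code passes args.lr_scheduler_num_cycles / 2 here
    min_lr_rate=0.1,  # the LR floors at 10% of the initial LR
)

for _ in range(1000):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # ~[1e-05], i.e. 0.1 * initial LR
```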

requirements.txt

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,5 +1,5 @@
 accelerate==0.30.0
-transformers==4.41.2
+transformers==4.44.0
 diffusers[torch]==0.25.0
 ftfy==6.1.1
 # albumentations==1.3.0
@@ -16,7 +16,7 @@ altair==4.2.2
 easygui==0.98.3
 toml==0.10.2
 voluptuous==0.13.1
-huggingface-hub==0.23.3
+huggingface-hub==0.24.5
 # for Image utils
 imagesize==1.4.1
 # for BLIP captioning
```
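To pick up these new pins, the README's [Upgrade](#upgrade) section (referenced in the README change above) describes the usual procedure: `git pull` followed by `pip install --use-pep517 --upgrade -r requirements.txt` inside the activated virtual environment. The exact command is assumed from the project's standard upgrade instructions; verify against your setup.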
