Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

- The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds!

- `train_network.py` and `sdxl_train_network.py` now restore the order/position of data loading from DataSet when resuming training. PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) Thanks to KohakuBlueleaf!
- This resolves the issue where the order of data loading from DataSet changes when resuming training.
- Specify the `--skip_until_initial_step` option to skip data loading until the specified step. If not specified, data loading starts from the beginning of the DataSet (same as before).
- If `--resume` is specified, the step saved in the state is used.
- Specify the `--initial_step` or `--initial_epoch` option to skip data loading until the specified step or epoch. Use these options in conjunction with `--skip_until_initial_step`. These options can be used without `--resume` (use them when resuming training with `--network_weights`).

- An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra!
- It seems that the model file loading is faster in the WSL environment etc.
- Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`.
Expand Down Expand Up @@ -235,6 +241,12 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!

- SD1.5/2.x 用の ControlNet 学習スクリプト `train_controlnet.py` が動作しなくなっていたのが修正されました。PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) sdbds 氏に感謝します。

- `train_network.py` および `sdxl_train_network.py` で、学習再開時に DataSet の読み込み順についても復元できるようになりました。PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) KohakuBlueleaf 氏に感謝します。
- これにより、学習再開時に DataSet の読み込み順が変わってしまう問題が解消されます。
- `--skip_until_initial_step` オプションを指定すると、指定したステップまで DataSet 読み込みをスキップします。指定しない場合の動作は変わりません(DataSet の最初から読み込みます)
- `--resume` オプションを指定すると、state に保存されたステップ数が使用されます。
- `--initial_step` または `--initial_epoch` オプションを指定すると、指定したステップまたはエポックまで DataSet 読み込みをスキップします。これらのオプションは `--skip_until_initial_step` と併用してください。またこれらのオプションは `--resume` と併用しなくても使えます(`--network_weights` を用いた学習再開時などにお使いください )。

- SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。
- WSL 環境等でモデルファイルの読み込みが高速化されるようです。
- `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。
Expand Down
14 changes: 12 additions & 2 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,8 +657,16 @@ def set_caching_mode(self, mode):

def set_current_epoch(self, epoch):
if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする
self.shuffle_buckets()
self.current_epoch = epoch
if epoch > self.current_epoch:
logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
num_epochs = epoch - self.current_epoch
for _ in range(num_epochs):
self.current_epoch += 1
self.shuffle_buckets()
# self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader?
else:
logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
self.current_epoch = epoch

def set_current_step(self, step):
self.current_step = step
Expand Down Expand Up @@ -5553,6 +5561,8 @@ def add(self, *, epoch: int, step: int, loss: float) -> None:
if epoch == 0:
self.loss_list.append(loss)
else:
while len(self.loss_list) <= step:
self.loss_list.append(0.0)
self.loss_total -= self.loss_list[step]
self.loss_list[step] = loss
self.loss_total += loss
Expand Down
105 changes: 102 additions & 3 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,15 @@ def save_model_hook(models, weights, output_dir):
weights.pop(i)
# print(f"save model hook: {len(weights)} weights will be saved")

# save current ecpoch and step
train_state_file = os.path.join(output_dir, "train_state.json")
# +1 is needed because the state is saved before current_step is set from global_step
logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}")
with open(train_state_file, "w", encoding="utf-8") as f:
json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)

steps_from_state = None

def load_model_hook(models, input_dir):
# remove models except network
remove_indices = []
Expand All @@ -514,6 +523,15 @@ def load_model_hook(models, input_dir):
models.pop(i)
# print(f"load model hook: {len(models)} models will be loaded")

# load current epoch and step to
nonlocal steps_from_state
train_state_file = os.path.join(input_dir, "train_state.json")
if os.path.exists(train_state_file):
with open(train_state_file, "r", encoding="utf-8") as f:
data = json.load(f)
steps_from_state = data["current_step"]
logger.info(f"load train state from {train_state_file}: {data}")

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

Expand Down Expand Up @@ -757,7 +775,54 @@ def load_model_hook(models, input_dir):
if key in metadata:
minimum_metadata[key] = metadata[key]

progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
# calculate steps to skip when resuming or starting from a specific step
initial_step = 0
if args.initial_epoch is not None or args.initial_step is not None:
# if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
if steps_from_state is not None:
logger.warning(
"steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
)
if args.initial_step is not None:
initial_step = args.initial_step
else:
# num steps per epoch is calculated by num_processes and gradient_accumulation_steps
initial_step = (args.initial_epoch - 1) * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
else:
# if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
if steps_from_state is not None:
initial_step = steps_from_state
steps_from_state = None

if initial_step > 0:
assert (
args.max_train_steps > initial_step
), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"

progress_bar = tqdm(
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
)

epoch_to_start = 0
if initial_step > 0:
if args.skip_until_initial_step:
# if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
if not args.resume:
logger.info(
f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
)
logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
initial_step *= args.gradient_accumulation_steps

# set epoch to start to make initial_step less than len(train_dataloader)
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
else:
# if not, only epoch no is skipped for informative purpose
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
initial_step = 0 # do not skip

global_step = 0

noise_scheduler = DDPMScheduler(
Expand Down Expand Up @@ -816,16 +881,31 @@ def remove_model(old_ckpt_name):
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

# training loop
for epoch in range(num_train_epochs):
if initial_step > 0: # only if skip_until_initial_step is specified
for skip_epoch in range(epoch_to_start): # skip epochs
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
initial_step -= len(train_dataloader)
global_step = initial_step

for epoch in range(epoch_to_start, num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

metadata["ss_epoch"] = str(epoch + 1)

accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)

for step, batch in enumerate(train_dataloader):
skipped_dataloader = None
if initial_step > 0:
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
initial_step = 1

for step, batch in enumerate(skipped_dataloader or train_dataloader):
current_step.value = global_step
if initial_step > 0:
initial_step -= 1
continue

with accelerator.accumulate(training_model):
on_step_start(text_encoder, unet)

Expand Down Expand Up @@ -1126,6 +1206,25 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
parser.add_argument(
"--skip_until_initial_step",
action="store_true",
help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
)
parser.add_argument(
"--initial_epoch",
type=int,
default=None,
help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`."
+ " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
)
parser.add_argument(
"--initial_step",
type=int,
default=None,
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
)
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
Expand Down