Skip to content

Commit 418e5b3

Browse files
authored
Sd3 freeze x_block (kohya-ss#1417)
* Update sd3_train.py * add freeze block lr * Update train_util.py * update
1 parent d8a3b4b commit 418e5b3

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

library/train_util.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3246,6 +3246,12 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
32463246
default=None,
32473247
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
32483248
)
3249+
parser.add_argument(
3250+
"--num_last_block_to_freeze",
3251+
type=int,
3252+
default=None,
3253+
help="num_last_block_to_freeze",
3254+
)
32493255

32503256

32513257
def add_optimizer_arguments(parser: argparse.ArgumentParser):
@@ -5758,6 +5764,21 @@ def sample_image_inference(
57585764
pass
57595765

57605766

5767+
def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"):
    """Disable gradients on the trailing parameters whose name contains *block_name*.

    The entries yielded by ``model.named_parameters()`` are filtered by a
    substring match on the parameter name, and the last
    ``num_last_block_to_freeze`` of the matching entries get
    ``requires_grad = False``. Note that the count is applied to individual
    parameter entries, not to whole blocks.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs.
        num_last_block_to_freeze: How many of the matching entries, counted
            from the end, to freeze. Values larger than the number of matches
            are clamped to that number; zero or a negative value freezes
            nothing.
        block_name: Substring a parameter name must contain to be considered.

    Returns:
        None. Matching parameters are modified in place, and the number of
        matches and the number frozen are printed.
    """
    matching_params = [tensor for key, tensor in model.named_parameters() if block_name in key]
    total_matches = len(matching_params)
    print(f"filtered_blocks: {total_matches}")

    # Never try to freeze more entries than actually matched.
    freeze_count = min(total_matches, num_last_block_to_freeze)
    print(f"freeze_blocks: {freeze_count}")

    # Index of the first entry to freeze; everything from here to the end is frozen.
    first_frozen = max(0, total_matches - freeze_count)
    for tensor in matching_params[first_frozen:]:
        tensor.requires_grad = False
5781+
57615782
# endregion
57625783

57635784

sd3_train.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,12 +368,19 @@ def train(args):
368368
vae.eval()
369369
vae.to(accelerator.device, dtype=vae_dtype)
370370

371+
mmdit.requires_grad_(train_mmdit)
372+
if not train_mmdit:
373+
mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared
374+
375+
if args.num_last_block_to_freeze:
376+
train_util.freeze_blocks(mmdit,num_last_block_to_freeze=args.num_last_block_to_freeze)
377+
371378
training_models = []
372379
params_to_optimize = []
373380
# if train_unet:
374381
training_models.append(mmdit)
375382
# if block_lrs is None:
376-
params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate})
383+
params_to_optimize.append({"params": list(filter(lambda p: p.requires_grad, mmdit.parameters())), "lr": args.learning_rate})
377384
# else:
378385
# params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs))
379386

0 commit comments

Comments
 (0)