Commit fc37437

Allow negative learning rate
This can be used to train away from a group of images you don't want. As this moves the model away from a point instead of towards it, the change in the model is unbounded. So don't set it too low; -4e-7 seemed to work well.
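A minimal standalone sketch (not code from this repository) of why the commit message warns that the change is unbounded: plain SGD on a 1-D quadratic loss. With a positive learning rate each step shrinks the parameter toward the minimum, so the update is self-limiting; with a negative learning rate each step moves the parameter up the gradient, and the distance from the minimum grows without bound.

```python
def sgd_steps(lr, w=1.0, steps=20):
    """Run `steps` plain-SGD updates on loss(w) = w**2, starting from w."""
    for _ in range(steps):
        grad = 2.0 * w      # d/dw of w**2
        w = w - lr * grad   # negative lr flips this into gradient *ascent*
    return w

# Positive lr: w shrinks by a factor 0.8 per step and converges toward 0.
# Negative lr: w grows by a factor 1.2 per step and diverges without bound.
print(abs(sgd_steps(lr=0.1)))   # small, near 0
print(abs(sgd_steps(lr=-0.1)))  # large, still growing
```

This is why the message suggests a tiny magnitude such as -4e-7: the divergence is only kept manageable by making each ascent step very small.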
1 parent 71e2c91 commit fc37437

File tree: 1 file changed (+3 −3 lines)


sdxl_train.py (+3 −3)
```diff
@@ -272,7 +272,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     # prepare for training: put the model into the appropriate state
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
-    train_unet = args.learning_rate > 0
+    train_unet = args.learning_rate != 0
     train_text_encoder1 = False
     train_text_encoder2 = False

@@ -284,8 +284,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         text_encoder2.gradient_checkpointing_enable()
     lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate  # 0 means not train
     lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate  # 0 means not train
-    train_text_encoder1 = lr_te1 > 0
-    train_text_encoder2 = lr_te2 > 0
+    train_text_encoder1 = lr_te1 != 0
+    train_text_encoder2 = lr_te2 != 0

     # caching one text encoder output is not supported
     if not train_text_encoder1:
```
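The effect of the three changed lines, as a standalone sketch (the function name is hypothetical; it mirrors the diff's logic but is not the repository's code): a learning rate of exactly 0 still disables training, while any nonzero value, including a negative one, now enables it.

```python
# Hypothetical helper mirroring the changed gating logic in the diff:
# 0 means "do not train"; any nonzero lr (even negative) enables training.
def resolve_train_flags(learning_rate, learning_rate_te1=None, learning_rate_te2=None):
    lr_te1 = learning_rate_te1 if learning_rate_te1 is not None else learning_rate
    lr_te2 = learning_rate_te2 if learning_rate_te2 is not None else learning_rate
    train_unet = learning_rate != 0    # was: learning_rate > 0
    train_text_encoder1 = lr_te1 != 0  # was: lr_te1 > 0
    train_text_encoder2 = lr_te2 != 0  # was: lr_te2 > 0
    return train_unet, train_text_encoder1, train_text_encoder2

print(resolve_train_flags(-4e-7))        # (True, True, True)
print(resolve_train_flags(-4e-7, 0, 0))  # (True, False, False)
print(resolve_train_flags(0))            # (False, False, False)
```

Under the old `> 0` comparison, the first call would have returned `(False, False, False)` and a negative learning rate would silently have skipped training entirely.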
