Skip to content

Commit 6a5f87d

Browse files
committed
disable weighted captions in TI/XTI training
1 parent a876f2d commit 6a5f87d

3 files changed

+10
-9
lines changed

library/custom_train_functions.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,20 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
1818
return loss
1919

2020

21-
def add_custom_train_arguments(parser: argparse.ArgumentParser):
21+
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
2222
parser.add_argument(
2323
"--min_snr_gamma",
2424
type=float,
2525
default=None,
2626
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
2727
)
28-
parser.add_argument(
29-
"--weighted_captions",
30-
action="store_true",
31-
default=False,
32-
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder.",
33-
)
28+
if support_weighted_captions:
29+
parser.add_argument(
30+
"--weighted_captions",
31+
action="store_true",
32+
default=False,
33+
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
34+
)
3435

3536

3637
re_attention = re.compile(

train_textual_inversion.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ def setup_parser() -> argparse.ArgumentParser:
549549
train_util.add_training_arguments(parser, True)
550550
train_util.add_optimizer_arguments(parser)
551551
config_util.add_config_arguments(parser)
552-
custom_train_functions.add_custom_train_arguments(parser)
552+
custom_train_functions.add_custom_train_arguments(parser, False)
553553

554554
parser.add_argument(
555555
"--save_model_as",

train_textual_inversion_XTI.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ def setup_parser() -> argparse.ArgumentParser:
603603
train_util.add_training_arguments(parser, True)
604604
train_util.add_optimizer_arguments(parser)
605605
config_util.add_config_arguments(parser)
606-
custom_train_functions.add_custom_train_arguments(parser)
606+
custom_train_functions.add_custom_train_arguments(parser, False)
607607

608608
parser.add_argument(
609609
"--save_model_as",

0 commit comments

Comments
 (0)