Skip to content

Commit f33e733

Browse files
committed
Refactor code to use keyword arguments in train_util.py
1 parent 3fb639e commit f33e733

File tree

1 file changed

+6
-114
lines changed

1 file changed

+6
-114
lines changed

library/train_util.py

Lines changed: 6 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -416,49 +416,13 @@ def __init__(
416416
class_tokens: Optional[str],
417417
caption_extension: str,
418418
cache_info: bool,
419-
num_repeats,
420-
shuffle_caption,
421-
caption_separator: str,
422-
keep_tokens,
423-
keep_tokens_separator,
424-
secondary_separator,
425-
enable_wildcard,
426-
color_aug,
427-
flip_aug,
428-
face_crop_aug_range,
429-
random_crop,
430-
caption_dropout_rate,
431-
caption_dropout_every_n_epochs,
432-
caption_tag_dropout_rate,
433-
caption_prefix,
434-
caption_suffix,
435-
token_warmup_min,
436-
token_warmup_step,
437-
alpha_mask,
419+
**kwargs,
438420
) -> None:
439421
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
440422

441423
super().__init__(
442424
image_dir,
443-
num_repeats,
444-
shuffle_caption,
445-
caption_separator,
446-
keep_tokens,
447-
keep_tokens_separator,
448-
secondary_separator,
449-
enable_wildcard,
450-
color_aug,
451-
flip_aug,
452-
face_crop_aug_range,
453-
random_crop,
454-
caption_dropout_rate,
455-
caption_dropout_every_n_epochs,
456-
caption_tag_dropout_rate,
457-
caption_prefix,
458-
caption_suffix,
459-
token_warmup_min,
460-
token_warmup_step,
461-
alpha_mask,
425+
**kwargs,
462426
)
463427

464428
self.is_reg = is_reg
@@ -479,49 +443,13 @@ def __init__(
479443
self,
480444
image_dir,
481445
metadata_file: str,
482-
num_repeats,
483-
shuffle_caption,
484-
caption_separator,
485-
keep_tokens,
486-
keep_tokens_separator,
487-
secondary_separator,
488-
enable_wildcard,
489-
color_aug,
490-
flip_aug,
491-
face_crop_aug_range,
492-
random_crop,
493-
caption_dropout_rate,
494-
caption_dropout_every_n_epochs,
495-
caption_tag_dropout_rate,
496-
caption_prefix,
497-
caption_suffix,
498-
token_warmup_min,
499-
token_warmup_step,
500-
alpha_mask,
446+
**kwargs,
501447
) -> None:
502448
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
503449

504450
super().__init__(
505451
image_dir,
506-
num_repeats,
507-
shuffle_caption,
508-
caption_separator,
509-
keep_tokens,
510-
keep_tokens_separator,
511-
secondary_separator,
512-
enable_wildcard,
513-
color_aug,
514-
flip_aug,
515-
face_crop_aug_range,
516-
random_crop,
517-
caption_dropout_rate,
518-
caption_dropout_every_n_epochs,
519-
caption_tag_dropout_rate,
520-
caption_prefix,
521-
caption_suffix,
522-
token_warmup_min,
523-
token_warmup_step,
524-
alpha_mask,
452+
**kwargs,
525453
)
526454

527455
self.metadata_file = metadata_file
@@ -539,49 +467,13 @@ def __init__(
539467
conditioning_data_dir: str,
540468
caption_extension: str,
541469
cache_info: bool,
542-
num_repeats,
543-
shuffle_caption,
544-
caption_separator,
545-
keep_tokens,
546-
keep_tokens_separator,
547-
secondary_separator,
548-
enable_wildcard,
549-
color_aug,
550-
flip_aug,
551-
face_crop_aug_range,
552-
random_crop,
553-
caption_dropout_rate,
554-
caption_dropout_every_n_epochs,
555-
caption_tag_dropout_rate,
556-
caption_prefix,
557-
caption_suffix,
558-
token_warmup_min,
559-
token_warmup_step,
560-
alpha_mask,
470+
**kwargs,
561471
) -> None:
562472
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
563473

564474
super().__init__(
565475
image_dir,
566-
num_repeats,
567-
shuffle_caption,
568-
caption_separator,
569-
keep_tokens,
570-
keep_tokens_separator,
571-
secondary_separator,
572-
enable_wildcard,
573-
color_aug,
574-
flip_aug,
575-
face_crop_aug_range,
576-
random_crop,
577-
caption_dropout_rate,
578-
caption_dropout_every_n_epochs,
579-
caption_tag_dropout_rate,
580-
caption_prefix,
581-
caption_suffix,
582-
token_warmup_min,
583-
token_warmup_step,
584-
alpha_mask,
476+
**kwargs,
585477
)
586478

587479
self.conditioning_data_dir = conditioning_data_dir

0 commit comments

Comments
 (0)