@@ -409,6 +409,7 @@ def __init__(
 
         self.alpha_mask = alpha_mask
 
+
 class DreamBoothSubset(BaseSubset):
     def __init__(
         self,
@@ -417,13 +418,47 @@ def __init__(
         class_tokens: Optional[str],
         caption_extension: str,
         cache_info: bool,
-        **kwargs,
+        num_repeats,
+        shuffle_caption,
+        caption_separator: str,
+        keep_tokens,
+        keep_tokens_separator,
+        secondary_separator,
+        enable_wildcard,
+        color_aug,
+        flip_aug,
+        face_crop_aug_range,
+        random_crop,
+        caption_dropout_rate,
+        caption_dropout_every_n_epochs,
+        caption_tag_dropout_rate,
+        caption_prefix,
+        caption_suffix,
+        token_warmup_min,
+        token_warmup_step,
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
         super().__init__(
             image_dir,
-            **kwargs,
+            num_repeats,
+            shuffle_caption,
+            caption_separator,
+            keep_tokens,
+            keep_tokens_separator,
+            secondary_separator,
+            enable_wildcard,
+            color_aug,
+            flip_aug,
+            face_crop_aug_range,
+            random_crop,
+            caption_dropout_rate,
+            caption_dropout_every_n_epochs,
+            caption_tag_dropout_rate,
+            caption_prefix,
+            caption_suffix,
+            token_warmup_min,
+            token_warmup_step,
         )
 
         self.is_reg = is_reg
@@ -444,13 +479,47 @@ def __init__(
         self,
         image_dir,
         metadata_file: str,
-        **kwargs,
+        num_repeats,
+        shuffle_caption,
+        caption_separator,
+        keep_tokens,
+        keep_tokens_separator,
+        secondary_separator,
+        enable_wildcard,
+        color_aug,
+        flip_aug,
+        face_crop_aug_range,
+        random_crop,
+        caption_dropout_rate,
+        caption_dropout_every_n_epochs,
+        caption_tag_dropout_rate,
+        caption_prefix,
+        caption_suffix,
+        token_warmup_min,
+        token_warmup_step,
     ) -> None:
         assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
 
         super().__init__(
             image_dir,
-            **kwargs,
+            num_repeats,
+            shuffle_caption,
+            caption_separator,
+            keep_tokens,
+            keep_tokens_separator,
+            secondary_separator,
+            enable_wildcard,
+            color_aug,
+            flip_aug,
+            face_crop_aug_range,
+            random_crop,
+            caption_dropout_rate,
+            caption_dropout_every_n_epochs,
+            caption_tag_dropout_rate,
+            caption_prefix,
+            caption_suffix,
+            token_warmup_min,
+            token_warmup_step,
         )
 
         self.metadata_file = metadata_file
@@ -468,13 +537,47 @@ def __init__(
         conditioning_data_dir: str,
         caption_extension: str,
         cache_info: bool,
-        **kwargs,
+        num_repeats,
+        shuffle_caption,
+        caption_separator,
+        keep_tokens,
+        keep_tokens_separator,
+        secondary_separator,
+        enable_wildcard,
+        color_aug,
+        flip_aug,
+        face_crop_aug_range,
+        random_crop,
+        caption_dropout_rate,
+        caption_dropout_every_n_epochs,
+        caption_tag_dropout_rate,
+        caption_prefix,
+        caption_suffix,
+        token_warmup_min,
+        token_warmup_step,
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
         super().__init__(
             image_dir,
-            **kwargs,
+            num_repeats,
+            shuffle_caption,
+            caption_separator,
+            keep_tokens,
+            keep_tokens_separator,
+            secondary_separator,
+            enable_wildcard,
+            color_aug,
+            flip_aug,
+            face_crop_aug_range,
+            random_crop,
+            caption_dropout_rate,
+            caption_dropout_every_n_epochs,
+            caption_tag_dropout_rate,
+            caption_prefix,
+            caption_suffix,
+            token_warmup_min,
+            token_warmup_step,
         )
 
         self.conditioning_data_dir = conditioning_data_dir
@@ -1100,10 +1203,12 @@ def __getitem__(self, index):
                 else:
                     latents = image_info.latents_flipped
                     alpha_mask = image_info.alpha_mask_flipped
-
+
                 image = None
             elif image_info.latents_npz is not None:  # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
-                latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(image_info.latents_npz)
+                latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(
+                    image_info.latents_npz
+                )
                 if flipped:
                     latents = flipped_latents
                     alpha_mask = flipped_alpha_mask
@@ -1116,7 +1221,9 @@ def __getitem__(self, index):
                 image = None
             else:
                 # 画像を読み込み、必要ならcropする
-                img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path, subset.alpha_mask)
+                img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(
+                    subset, image_info.absolute_path, subset.alpha_mask
+                )
                 im_h, im_w = img.shape[0:2]
 
                 if self.enable_bucket:
@@ -1157,7 +1264,7 @@ def __getitem__(self, index):
                     if img.shape[2] == 4:
                         alpha_mask = img[:, :, 3]  # [W,H]
                     else:
-                        alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H]
+                        alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8)  # [W,H]
                     alpha_mask = transforms.ToTensor()(alpha_mask)
                 else:
                     alpha_mask = None
@@ -2070,7 +2177,14 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
 # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
 def load_latents_from_disk(
     npz_path,
-) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+) -> Tuple[
+    Optional[torch.Tensor],
+    Optional[List[int]],
+    Optional[List[int]],
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+]:
     npz = np.load(npz_path)
     if "latents" not in npz:
         raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
@@ -2084,7 +2198,9 @@ def load_latents_from_disk(
     return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask
 
 
-def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None):
+def save_latents_to_disk(
+    npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None
+):
     kwargs = {}
     if flipped_latents_tensor is not None:
         kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
@@ -2344,10 +2460,10 @@ def cache_batch_latents(
         image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
         if info.use_alpha_mask:
             if image.shape[2] == 4:
-                alpha_mask = image[:, :, 3] # [W,H]
+                alpha_mask = image[:, :, 3]  # [W,H]
                 image = image[:, :, :3]
             else:
-                alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H]
+                alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8)  # [W,H]
             alpha_masks.append(transforms.ToTensor()(alpha_mask))
         image = IMAGE_TRANSFORMS(image)
         images.append(image)
@@ -2377,13 +2493,23 @@ def cache_batch_latents(
         flipped_latents = [None] * len(latents)
         flipped_alpha_masks = [None] * len(image_infos)
 
-    for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks):
+    for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(
+        image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks
+    ):
         # check NaN
         if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
             raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
 
         if cache_to_disk:
-            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, alpha_mask, flipped_alpha_mask)
+            save_latents_to_disk(
+                info.latents_npz,
+                latent,
+                info.latents_original_size,
+                info.latents_crop_ltrb,
+                flipped_latent,
+                alpha_mask,
+                flipped_alpha_mask,
+            )
         else:
             info.latents = latent
             if flip_aug:
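
Aside from the line wrapping, the substantive change above is that the subset constructors no longer forward their options through `**kwargs` but list every parameter explicitly. A minimal standalone sketch of that pattern is below; the `Toy*` class names and values are hypothetical and are not part of `train_util.py`:

```python
# Toy illustration of replacing **kwargs passthrough with explicit forwarding.
# With explicit signatures, a misspelled or missing option fails at the call
# site instead of being silently swallowed by **kwargs.
from typing import Optional


class ToyBaseSubset:
    def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool) -> None:
        self.image_dir = image_dir
        self.num_repeats = num_repeats
        self.shuffle_caption = shuffle_caption


class ToyDreamBoothSubset(ToyBaseSubset):
    def __init__(self, image_dir: Optional[str], is_reg: bool, num_repeats: int, shuffle_caption: bool) -> None:
        # explicit forwarding, mirroring the diff's style, instead of super().__init__(**kwargs)
        super().__init__(image_dir, num_repeats, shuffle_caption)
        self.is_reg = is_reg


subset = ToyDreamBoothSubset("train/1_dog", is_reg=False, num_repeats=1, shuffle_caption=True)
print(subset.num_repeats)  # -> 1
```

The trade-off is the one visible in the diff: explicit parameter lists must be kept in sync with the base class by hand, but each subclass signature documents exactly which options it accepts.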