@@ -68,6 +68,8 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool,
     self.latents_flipped: torch.Tensor = None
     self.latents_npz: str = None
     self.latents_npz_flipped: str = None
+    self.mask: np.ndarray = None
+    self.mask_flipped: np.ndarray = None


 class BucketManager():
@@ -467,9 +469,12 @@ def shuffle_buckets(self):

   def load_image(self, image_path):
     image = Image.open(image_path)
-    if not image.mode == "RGB":
-      image = image.convert("RGB")
+    # if not image.mode == "RGB":
+    #   image = image.convert("RGB")
+    if not image.mode == "RGBA":
+      image = image.convert("RGBA")
     img = np.array(image, np.uint8)
+    # alpha_channel = np.array(image, np.uint8)[:, :, -1]
     return img

   def trim_and_resize_if_required(self, image, reso, resized_size):
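With load_image now returning RGBA, the alpha plane travels with the pixel data and can be split off wherever it is needed. A minimal sketch of that split, mirroring the slicing used later in this patch (the file name and variable names are illustrative only, not part of the diff):

import numpy as np
from PIL import Image

# Open an image the same way the patched load_image() does: force RGBA so an
# alpha channel is always present, then treat it as a uint8 array of shape (H, W, 4).
img = np.array(Image.open("example.png").convert("RGBA"), np.uint8)
rgb = img[:, :, :3]                       # colour channels that go on to the VAE
alpha = img[:, :, -1]                     # alpha channel, 0..255
mask = alpha.astype(np.float32) / 255.0   # 1.0 = full loss weight, 0.0 = ignored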
@@ -508,16 +513,19 @@ def cache_latents(self, vae):

       image = self.load_image(info.absolute_path)
       image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
-
+      mask = image[:, :, -1]  # grab alpha channel
+      image = image[:, :, :3]  # drop alpha channel
       img_tensor = self.image_transforms(image)
       img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
       info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
+      info.mask = mask / 255

       if self.flip_aug:
         image = image[:, ::-1].copy()  # cannot convert to Tensor without copy
         img_tensor = self.image_transforms(image)
         img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
         info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
+        info.mask_flipped = mask[:, ::-1] / 255  # flip along the width axis, matching image[:, ::-1]

   def get_image_size(self, image_path):
     image = Image.open(image_path)
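The flipped mask has to mirror the same axis as the flipped image: image[:, ::-1] reverses the width dimension of an (H, W, C) array, so for the 2-D (H, W) mask the matching slice is mask[:, ::-1], not mask[::-1] (which would flip vertically). A small sanity check, not part of the patch:

import numpy as np

mask = np.arange(6, dtype=np.uint8).reshape(2, 3)
# Horizontal flip of the 2-D mask matches the image's width-axis flip.
assert np.array_equal(mask[:, ::-1], np.fliplr(mask))
# Reversing the first axis flips top-to-bottom instead, which would not match.
assert not np.array_equal(mask[::-1], np.fliplr(mask))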
@@ -606,14 +614,17 @@ def __getitem__(self, index):
     input_ids_list = []
     latents_list = []
     images = []
+    masks = []

     for image_key in bucket[image_index:image_index + bucket_batch_size]:
       image_info = self.image_data[image_key]
       loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)

       # image/latentsを処理する
       if image_info.latents is not None:
-        latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
+        rand_flip = random.random()  # draw once so latents and mask are flipped together
+        latents = image_info.latents if not self.flip_aug or rand_flip < .5 else image_info.latents_flipped
+        mask = image_info.mask if not self.flip_aug or rand_flip < .5 else image_info.mask_flipped
         image = None
       elif image_info.latents_npz is not None:
         latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
@@ -622,6 +633,8 @@ def __getitem__(self, index):
       else:
         # 画像を読み込み、必要ならcropする
         img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
+        mask = img[:, :, -1]  # grab alpha channel
+        img = img[:, :, :3]  # drop alpha channel
         im_h, im_w = img.shape[0:2]

         if self.enable_bucket:
@@ -647,7 +660,8 @@ def __getitem__(self, index):

         latents = None
         image = self.image_transforms(img)  # -1.0~1.0のtorch.Tensorになる
-
+        mask = (self.image_transforms(mask) + 1) * .5  # undo the [-1, 1] normalization so the mask stays in [0, 1]
+      masks.append(torch.tensor(mask))
       images.append(image)
       latents_list.append(latents)

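The (self.image_transforms(mask) + 1) * .5 step assumes image_transforms is the usual ToTensor + Normalize([0.5], [0.5]) pipeline used for the images (its definition is not shown in this diff, so treat that as an assumption): ToTensor scales the uint8 alpha to [0, 1], Normalize maps it to [-1, 1], and + 1) * .5 brings it back to [0, 1], matching the cached info.mask = mask / 255. A quick round-trip check under that assumption:

import numpy as np
from torchvision import transforms

# Assumed to match self.image_transforms; an assumption, not quoted from the repo.
image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

alpha = np.array([[0, 128, 255]], dtype=np.uint8)   # sample alpha values
t = image_transforms(alpha)                         # ToTensor -> [0, 1], Normalize -> [-1, 1]
mask = (t + 1) * .5                                 # back to [0, 1]
print(mask)                                         # tensor([[[0.0000, 0.5020, 1.0000]]])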
@@ -672,7 +686,7 @@ def __getitem__(self, index):
     else:
       images = None
     example['images'] = images
-
+    example['masks'] = torch.stack(masks) if masks[0] is not None else None
     example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None

     if self.debug_dataset:
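This diff only gets the masks into the batch; how a training script consumes example['masks'] together with the --masked_loss flag added below is not part of the patch. One possible sketch, using a hypothetical helper name: the pixel-space mask is resized to the latent grid (1/8 of the image size) and used to weight the unreduced MSE, so fully transparent regions contribute no gradient.

import torch
import torch.nn.functional as F

def apply_mask_to_loss(loss: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    # loss:  (B, 4, h, w) unreduced MSE between predicted and target noise
    # masks: (B, H, W) or (B, 1, H, W), values in [0, 1], from batch['masks']
    if masks.dim() == 3:
        masks = masks.unsqueeze(1)                    # -> (B, 1, H, W)
    # Latents are 8x smaller than the source image, so resize the mask to the latent grid.
    masks = F.interpolate(masks.float(), size=loss.shape[-2:], mode="bilinear", align_corners=False)
    return (loss * masks).mean()

# Hypothetical use inside a training step, gated on the new flag:
# loss = F.mse_loss(noise_pred, noise, reduction="none")
# loss = apply_mask_to_loss(loss, batch["masks"]) if args.masked_loss else loss.mean()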
@@ -1494,6 +1508,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
                       help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
   parser.add_argument("--bucket_no_upscale", action="store_true",
                       help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
+  parser.add_argument("--masked_loss", action="store_true",
+                      help="enable masked loss, using each image's alpha channel as the loss mask")

   if support_caption_dropout:
     # Textual Inversion はcaptionのdropoutをsupportしない
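The new option behaves like the other store_true dataset flags; a tiny parse check, illustrative only:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--masked_loss", action="store_true",
                    help="enable masked loss, using each image's alpha channel as the loss mask")

assert parser.parse_args(["--masked_loss"]).masked_loss is True   # enabled when passed
assert parser.parse_args([]).masked_loss is False                 # off by default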
@@ -2059,4 +2075,4 @@ def __getitem__(self, idx):
     return (tensor_pil, img_path)


-# endregion
+# endregion