@@ -1265,7 +1265,8 @@ def __getitem__(self, index):
12651265 if subset .alpha_mask :
12661266 if img .shape [2 ] == 4 :
12671267 alpha_mask = img [:, :, 3 ] # [H,W]
1268- alpha_mask = transforms .ToTensor ()(alpha_mask ) # 0-255 -> 0-1
1268+ alpha_mask = alpha_mask .astype (np .float32 ) / 255.0 # 0.0~1.0
1269+ alpha_mask = torch .FloatTensor (alpha_mask )
12691270 else :
12701271 alpha_mask = torch .ones ((img .shape [0 ], img .shape [1 ]), dtype = torch .float32 )
12711272 else :
@@ -2211,7 +2212,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
22112212# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
22122213def load_latents_from_disk (
22132214 npz_path ,
2214- ) -> Tuple [Optional [torch . Tensor ], Optional [List [int ]], Optional [List [int ]], Optional [np .ndarray ], Optional [np .ndarray ]]:
2215+ ) -> Tuple [Optional [np . ndarray ], Optional [List [int ]], Optional [List [int ]], Optional [np .ndarray ], Optional [np .ndarray ]]:
22152216 npz = np .load (npz_path )
22162217 if "latents" not in npz :
22172218 raise ValueError (f"error: npz is old format. please re-generate { npz_path } " )
@@ -2229,7 +2230,7 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli
22292230 if flipped_latents_tensor is not None :
22302231 kwargs ["latents_flipped" ] = flipped_latents_tensor .float ().cpu ().numpy ()
22312232 if alpha_mask is not None :
2232- kwargs ["alpha_mask" ] = alpha_mask # ndarray
2233+ kwargs ["alpha_mask" ] = alpha_mask . float (). cpu (). numpy ()
22332234 np .savez (
22342235 npz_path ,
22352236 latents = latents_tensor .float ().cpu ().numpy (),
@@ -2496,8 +2497,9 @@ def cache_batch_latents(
24962497 if image .shape [2 ] == 4 :
24972498 alpha_mask = image [:, :, 3 ] # [H,W]
24982499 alpha_mask = alpha_mask .astype (np .float32 ) / 255.0
2500+ alpha_mask = torch .FloatTensor (alpha_mask ) # [H,W]
24992501 else :
2500- alpha_mask = np .ones_like (image [:, :, 0 ], dtype = np .float32 )
2502+ alpha_mask = torch .ones_like (image [:, :, 0 ], dtype = torch .float32 ) # [H,W]
25012503 else :
25022504 alpha_mask = None
25032505 alpha_masks .append (alpha_mask )
0 commit comments