Skip to content

Commit aee2525

Browse files
committed
Fix alpha mask initialization
1 parent d6484b1 commit aee2525

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

library/train_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,7 +1144,7 @@ def __getitem__(self, index):
11441144
if img.shape[2] == 4:
11451145
alpha_mask = img[:, :, 3] # [W,H]
11461146
else:
1147-
alpha_mask = np.ones_like(img[:, :, 0]) # [W,H]
1147+
alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H]
11481148
img = img[:, :, :3] # remove alpha channel
11491149

11501150
# augmentation
@@ -2337,7 +2337,7 @@ def cache_batch_latents(
23372337
alpha_mask = image[:, :, 3] # [W,H]
23382338
image = image[:, :, :3]
23392339
else:
2340-
alpha_mask = np.ones_like(image[:, :, 0]) # [W,H]
2340+
alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H]
23412341
alpha_masks.append(alpha_mask)
23422342
else:
23432343
alpha_masks.append(None)

0 commit comments

Comments
 (0)