Skip to content

Commit 21162c7

Browse files
committed
Fix alpha_mask transformation
1 parent aee2525 commit 21162c7

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

library/train_util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1169,7 +1169,7 @@ def __getitem__(self, index):
11691169
alpha_mask = cv2.resize(
11701170
alpha_mask, target_size, interpolation=cv2.INTER_AREA
11711171
)
1172-
alpha_mask = self.image_transforms(alpha_mask)
1172+
alpha_mask = transforms.ToTensor()(alpha_mask)
11731173
alpha_mask_list.append(alpha_mask)
11741174

11751175
if not flipped:
@@ -2338,6 +2338,7 @@ def cache_batch_latents(
23382338
image = image[:, :, :3]
23392339
else:
23402340
alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H]
2341+
alpha_mask = transforms.ToTensor()(alpha_mask)
23412342
alpha_masks.append(alpha_mask)
23422343
else:
23432344
alpha_masks.append(None)

0 commit comments

Comments
 (0)