Skip to content

Commit

Permalink
fix the bug of random flip
Browse files Browse the repository at this point in the history
  • Loading branch information
MingtaoGuo authored Oct 19, 2022
1 parent 61d72e4 commit 6697e09
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions training/data_loader/dataset_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from PIL import Image

class FaceDataset(Dataset):
def __init__(self, img_path, mask_path, resolution=512, same_prob=0.2):
def __init__(self, img_path, mask_path, resolution=512, same_prob=0.1):
self.resolution = resolution
self.same_prob = same_prob
self.img_path = img_path
Expand All @@ -20,14 +20,22 @@ def __init__(self, img_path, mask_path, resolution=512, same_prob=0.2):
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
transforms.RandomHorizontalFlip()
])
self.transforms_mask = transforms.Compose([
transforms.ToTensor(),
transforms.RandomHorizontalFlip()
])

def __len__(self):
return self.length

def __getitem__(self, index):
seed = torch.random.seed()
torch.random.manual_seed(seed)
target = Image.open(self.img_path + "/" + self.files[index]).resize([self.resolution, self.resolution])
target = self.transforms(target)
mask = np.array(Image.open(self.mask_path + "/" + self.files[index]).resize([self.resolution, self.resolution]))
torch.random.manual_seed(seed)
mask = Image.open(self.mask_path + "/" + self.files[index]).resize([self.resolution, self.resolution])
mask = self.transforms_mask(mask)

if np.random.uniform() < self.same_prob:
source = Image.open(self.img_path + "/" + self.files[index]).resize([self.resolution, self.resolution])
Expand All @@ -39,12 +47,9 @@ def __getitem__(self, index):
source = self.transforms(source)
same = torch.tensor(0.)

mask = mask[None]/255.
mask[mask > 0.5] = 1
mask[mask <= 0.5] = 0
mask = torch.from_numpy(np.float32(mask))
mask = (mask > 0.5).float()

return target, source, mask, same
return target, source, mask[0:1], same

# from tqdm import tqdm
# path = "/data1/GMT/Dataset/thumbnails128x128/"
Expand Down

0 comments on commit 6697e09

Please sign in to comment.