-
Notifications
You must be signed in to change notification settings - Fork 2
/
dataset.py
109 lines (81 loc) · 2.78 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import numpy as np
import os
import torch
from PIL import Image
import utils
def sample_from_data(args, device, data_loader):
    """Pull the next batch of real images and labels and move it to ``device``.

    Args:
        args (argparse object): not used by this function.
        device (torch.device): target device for both returned tensors.
        data_loader: an iterator yielding ``(images, labels)`` pairs. It is
            advanced with ``next()``, so it must be an iterator (for example
            ``iter(DataLoader(...))``) rather than the DataLoader itself.

    Returns:
        tuple: ``(real, y)``, both moved to ``device``.
    """
    images, labels = next(data_loader)
    return images.to(device), labels.to(device)
def sample_from_gen(args, device, num_classes, gen):
    """Produce a batch of fake images from random noise and random labels.

    Args:
        args (argparse object): ``batch_size``, ``gen_dim_z`` and
            ``gen_distribution`` are read and forwarded to the ``utils``
            samplers.
        device (torch.device): forwarded to ``utils.sample_z`` and
            ``utils.sample_pseudo_labels``.
        num_classes (int): number of classes passed to the pseudo-label
            sampler.
        gen (nn.Module): generator, invoked as ``gen(z, pseudo_y)``.

    Returns:
        tuple: ``(fake, pseudo_y, z)``.
    """
    batch_size = args.batch_size
    noise = utils.sample_z(batch_size, args.gen_dim_z, device,
                           args.gen_distribution)
    labels = utils.sample_pseudo_labels(num_classes, batch_size, device)
    return gen(noise, labels), labels, noise
class FaceDataset(torch.utils.data.Dataset):
    """Image dataset held fully in memory, with one sub-folder per class.

    Expected layout is ``root/0/...``, ``root/1/...`` and so on, where the
    class label is the integer folder name. The number of classes is the
    number of sub-directories of ``root``. Folders are looked up as
    ``str(idx)`` for ``idx`` in ``range(num_classes)``, so a sub-directory
    with any other name is counted but its images are never loaded.

    Args:
        args (argparse object): only ``args.data_name`` is read. When it is
            ``'facescrub'``, every image that is not already 64x64 is
            resized to 64x64.
        root (str): dataset root directory.
        transform (callable, optional): applied to the stored PIL image in
            ``__getitem__``.
    """

    def __init__(self, args, root='', transform=None):
        super(FaceDataset, self).__init__()
        self.root = root
        self.transform = transform
        self.images = []  # list of (PIL image, class index) pairs
        self.path = self.root
        force_64 = args.data_name == 'facescrub'
        num_classes = len([lists for lists in os.listdir(
            self.path) if os.path.isdir(os.path.join(self.path, lists))])
        for idx in range(num_classes):
            class_path = os.path.join(self.path, str(idx))
            for dirpath, dirnames, files in os.walk(class_path):
                # os.walk / os.listdir return entries in arbitrary order;
                # sort so the sample order is reproducible.
                dirnames.sort()
                for img_name in sorted(files):
                    # Join with the directory actually being walked (not
                    # class_path) so images in nested sub-folders resolve.
                    image_path = os.path.join(dirpath, img_name)
                    self.images.append(
                        (self._load_image(image_path, force_64), idx))

    @staticmethod
    def _load_image(image_path, force_64):
        """Decode one image fully into memory, optionally resizing to 64x64."""
        # Image.open is lazy and holds the file open until the pixels are
        # read. Producing a new image (copy/resize) inside ``with`` decodes
        # the data and releases the handle, so big datasets cannot exhaust
        # file descriptors.
        with Image.open(image_path) as opened:
            if force_64 and opened.size != (64, 64):
                # LANCZOS is the filter the removed ``Image.ANTIALIAS``
                # alias pointed to (ANTIALIAS is gone in Pillow 10).
                return opened.resize((64, 64), Image.LANCZOS)
            return opened.copy()

    def __getitem__(self, index):
        """Return ``(image, label)``, applying ``transform`` if one was set."""
        img, label = self.images[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.images)
# Copied from https://github.com/naoto0804/pytorch-AdaIN/blob/master/sampler.py#L5-L15
def InfiniteSampler(n):
    """Yield indices in ``range(n)`` forever, reshuffling after every pass.

    Unlike the upstream pytorch-AdaIN version (which starts at position 0),
    the cursor starts at ``n - 1``. Only the last entry of the first
    permutation is yielded before the first reshuffle. Every reshuffle first
    calls ``np.random.seed()`` with no argument, which reseeds numpy's
    *global* RNG from fresh entropy. This affects all other users of
    ``np.random``.

    Args:
        n (int): number of indices to cycle through.

    Yields:
        numpy integer indices in ``[0, n)``.
    """
    cursor = n - 1
    perm = np.random.permutation(n)
    while True:
        yield perm[cursor]
        cursor += 1
        if cursor < n:
            continue
        # End of a pass: reseed globally, draw a fresh order, restart.
        np.random.seed()
        perm = np.random.permutation(n)
        cursor = 0
# Copied from https://github.com/naoto0804/pytorch-AdaIN/blob/master/sampler.py#L18-L26
class InfiniteSamplerWrapper(torch.utils.data.sampler.Sampler):
    """Sampler that streams dataset indices endlessly via ``InfiniteSampler``.

    Only the dataset size is retained. ``__len__`` reports ``2 ** 31`` as a
    stand-in for "unbounded", since the index stream never terminates.
    """

    def __init__(self, data_source):
        # Items are never accessed; the length is all the sampler needs.
        self.num_samples = len(data_source)

    def __len__(self):
        return 1 << 31

    def __iter__(self):
        index_stream = InfiniteSampler(self.num_samples)
        return iter(index_stream)