-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataset.py
100 lines (83 loc) · 2.37 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
from glob import glob
from typing import *

import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
"""
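Expected directory layout (pass ``path="dataset/"``):

dataset/
├── train
│   ├── class 1
│   │   ├── img1.jpg
│   │   └── ...
│   ├── class 2
│   │   ├── img1.jpg
│   │   └── ...
│   └── ...
├── valid
│   └── (same class sub-folders as train)
└── test
    └── (same class sub-folders as train)

Each class sub-folder name becomes a class label (torchvision ``ImageFolder`` layout).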
"""
# Pad each image to a centered square so the later square resize does not distort its aspect ratio.
class Padding:
    """Pad a PIL image onto a square canvas, keeping the original centered.

    Intended to run before a square ``Resize`` so the picture is letterboxed
    rather than stretched.

    Args:
        fill: background colour handed to ``Image.new`` for the padded canvas.
    """

    def __init__(self, fill):
        self.fill = fill

    def __call__(self, src):
        """Return ``src`` itself if already square, otherwise a new padded copy."""
        width, height = src.size
        if width == height:
            return src
        side = max(width, height)
        # Only the shorter dimension gets a non-zero offset, which centers it.
        left = (side - width) // 2
        top = (side - height) // 2
        canvas = Image.new(src.mode, (side, side), self.fill)
        canvas.paste(src, (left, top))
        return canvas
# Per-channel mean/std of ImageNet, the usual statistics for torchvision models.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
_SUBSETS = ('train', 'valid', 'test')


def _build_transform(subset, img_size, fill_color, normalization):
    """Compose the preprocessing pipeline; random augmentation only for 'train'."""
    steps = [
        # Square-pad first so the fixed (img_size, img_size) resize keeps aspect ratio.
        Padding(fill=fill_color),
        transforms.Resize((img_size, img_size)),
    ]
    if subset == 'train':
        steps += [
            transforms.RandomHorizontalFlip(p=0.5),
            # NOTE: corners exposed by rotation use RandomRotation's own default
            # fill (0), not fill_color.
            transforms.RandomRotation(degrees=(-20, 20)),
        ]
    steps.append(transforms.ToTensor())
    if normalization:
        # Normalize works on tensors, so it must come after ToTensor.
        steps.append(transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD))
    return transforms.Compose(steps)


def load_dataloader(
    path: str,
    normalization: bool = False,
    img_size: int = 224,
    fill_color: Tuple[int, int, int] = (0, 0, 0),
    subset: str = 'train',
    num_workers: int = 8,
    batch_size: int = 32,
    shuffle: bool = True,
    drop_last: bool = True,
):
    """Build a DataLoader over the ``ImageFolder`` tree at ``<path>/<subset>``.

    Every image is padded to a square with ``fill_color`` and resized to
    ``img_size`` x ``img_size``. The 'train' subset additionally gets a
    random horizontal flip (p=0.5) and a random rotation in [-20, 20] degrees.

    Args:
        path: dataset root holding 'train', 'valid' and 'test' folders.
            Joined with ``os.path.join``, so a trailing separator is optional.
        normalization: if True, normalize tensors with ImageNet mean/std.
        img_size: side length of the square output images.
        fill_color: colour of the padding added by ``Padding``.
        subset: one of 'train', 'valid', 'test'.
        num_workers: worker processes used by the DataLoader.
        batch_size: samples per batch.
        shuffle: passed to DataLoader unchanged for every subset.
        drop_last: passed to DataLoader unchanged for every subset.
            NOTE: the defaults (True) also apply to 'valid'/'test', where
            ``drop_last=True`` discards the final partial batch. Pass
            ``drop_last=False`` (and usually ``shuffle=False``) for evaluation.

    Returns:
        A ``torch.utils.data.DataLoader`` over the ``ImageFolder`` dataset.

    Raises:
        ValueError: if ``subset`` is not one of 'train', 'valid', 'test'.
        FileNotFoundError: if the resolved subset directory does not exist.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if subset not in _SUBSETS:
        raise ValueError(f"subset must be one of {_SUBSETS}, got {subset!r}")

    # os.path.join avoids the silent "datasettrain" result of plain string
    # concatenation when ``path`` lacks a trailing separator.
    data_path = os.path.join(path, subset)
    if not os.path.isdir(data_path):
        raise FileNotFoundError(f"dataset directory not found: {data_path}")

    images = ImageFolder(
        data_path,
        transform=_build_transform(subset, img_size, fill_color, normalization),
    )
    return DataLoader(
        images,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
    )