get_dataset.py
import json
import random

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

from dataset_helpers import crop_from_center, get_nine_crops


class GetDataset(Dataset):
    """Characterizes a dataset for PyTorch."""

    def __init__(self, file_paths, labels, transform=None):
        """Initialization."""
        self.imgs = [(img_path, label) for img_path, label in zip(file_paths, labels)]
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Denotes the total number of samples."""
        return len(self.file_paths)

    def __getitem__(self, index):
        """Generates one sample of data."""
        # Select sample
        file_path = self.file_paths[index]
        label = self.labels[index]
        pil_image = Image.open(file_path)

        # If the image does not have exactly 3 channels (e.g. grayscale or RGBA),
        # swap it with the 0th image. Assumption: the 0th image has 3 channels.
        if len(pil_image.getbands()) != 3:
            file_path = self.file_paths[0]
            label = self.labels[0]
            pil_image = Image.open(file_path)

        # Convert image to torch tensor (a transform is expected to be provided)
        tr_image = self.transform(pil_image)

        return tr_image, label
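

# Illustrative usage of GetDataset (a documentation sketch, not part of the original
# file): the paths, labels, and transform below are hypothetical placeholders. Any
# transform that ends in transforms.ToTensor() satisfies the "convert image to torch
# tensor" step in __getitem__. Kept as a comment so importing this module has no
# side effects.
#
#     from torchvision import transforms
#     from torch.utils.data import DataLoader
#
#     transform = transforms.Compose([
#         transforms.Resize((224, 224)),
#         transforms.ToTensor(),
#     ])
#     dataset = GetDataset(
#         file_paths=["data/img_000.jpg", "data/img_001.jpg"],  # hypothetical paths
#         labels=[0, 1],
#         transform=transform,
#     )
#     loader = DataLoader(dataset, batch_size=32, shuffle=True)
#     images, labels = next(iter(loader))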


class GetJigsawPuzzleDataset(Dataset):
    """Characterizes a jigsaw-puzzle dataset for PyTorch."""

    def __init__(self, file_paths, avail_permuts_file_path, range_permut_indices=None, transform=None):
        """Initialization."""
        self.file_paths = file_paths
        self.transform = transform
        self.permuts_avail = np.load(avail_permuts_file_path)
        self.range_permut_indices = range_permut_indices

    def __len__(self):
        """Denotes the total number of samples."""
        return len(self.file_paths)

    def __getitem__(self, index):
        """Generates one sample of data."""
        # Select sample
        file_path = self.file_paths[index]
        pil_image = Image.open(file_path)

        # If the image does not have exactly 3 channels (e.g. grayscale or RGBA),
        # swap it with the 0th image. Assumption: the 0th image has 3 channels.
        if len(pil_image.getbands()) != 3:
            file_path = self.file_paths[0]
            pil_image = Image.open(file_path)

        # Resize the image and crop its centre region before splitting it into patches
        pil_image = pil_image.resize((256, 256))
        pil_image = crop_from_center(pil_image, 225, 225)

        # Get nine crops for the image
        nine_crops = get_nine_crops(pil_image)

        # Permute the 9 patches obtained from the image
        if self.range_permut_indices:
            permut_ind = random.randint(self.range_permut_indices[0], self.range_permut_indices[1])
        else:
            permut_ind = random.randint(0, len(self.permuts_avail) - 1)
        permutation_config = self.permuts_avail[permut_ind]

        permuted_patches_arr = [None] * 9
        for crop_new_pos, crop in zip(permutation_config, nine_crops):
            permuted_patches_arr[crop_new_pos] = crop

        # Apply data transforms
        # TODO: Remove hard-coded patch dimensions (9 patches of 3 x 64 x 64)
        tensor_patches = torch.zeros(9, 3, 64, 64)
        for ind, jigsaw_patch in enumerate(permuted_patches_arr):
            jigsaw_patch_tr = self.transform(jigsaw_patch)
            tensor_patches[ind] = jigsaw_patch_tr

        return tensor_patches, permut_ind
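

if __name__ == "__main__":
    # Illustrative smoke test (a documentation sketch, not part of the original training
    # pipeline). It builds one synthetic RGB image and a small permutation file so
    # GetJigsawPuzzleDataset can be exercised end to end, assuming dataset_helpers is
    # importable and get_nine_crops returns nine PIL patches. The 64x64 patch transform
    # mirrors the hard-coded tensor shape above, and the .npy layout (one 9-element
    # permutation per row) is inferred from how permuts_avail is indexed in __getitem__.
    import itertools
    import os
    import tempfile

    from torchvision import transforms

    tmp_dir = tempfile.mkdtemp()

    # Synthetic 3-channel image saved to disk
    img_path = os.path.join(tmp_dir, "sample.jpg")
    Image.fromarray(np.uint8(np.random.rand(256, 256, 3) * 255)).save(img_path)

    # Small permutation set: the first 10 permutations of the 9 patch positions
    permuts_path = os.path.join(tmp_dir, "permutations.npy")
    np.save(permuts_path, np.array(list(itertools.islice(itertools.permutations(range(9)), 10))))

    patch_transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])
    jigsaw_dataset = GetJigsawPuzzleDataset(
        file_paths=[img_path],
        avail_permuts_file_path=permuts_path,
        range_permut_indices=(0, 9),
        transform=patch_transform,
    )
    patches, permut_ind = jigsaw_dataset[0]
    print(patches.shape, permut_ind)  # expected: torch.Size([9, 3, 64, 64]) <chosen index>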