-
Notifications
You must be signed in to change notification settings - Fork 6
/
batch_transformers.py
135 lines (107 loc) · 5.3 KB
/
batch_transformers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import collections
import collections.abc
import random

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
# Candidate resolutions (presumably edge lengths in pixels) for
# random-resolution sampling; no active code visible in this module reads it.
RANDOM_RESOLUTIONS = [512, 768, 1024, 1280, 1536]
def rand_crop(data, size):
    """Crop every entry of ``data`` to one shared random square window.

    The window origin is drawn once, bounded by ``data[0].size`` so that a
    ``size`` x ``size`` box fits inside the first entry, and the same box is
    then applied to every entry so they stay aligned with each other.

    Args:
        data: Mutable sequence of image-like objects exposing ``.size`` and
            ``.crop(box)`` (presumably PIL images). Modified in place.
        size: Edge length of the square crop window.

    Returns:
        The same ``data`` object, with each entry replaced by its crop.

    Raises:
        ValueError: From ``random.randint`` when ``size`` exceeds either
            dimension of the first entry.
    """
    reference = data[0].size
    # Draw the horizontal offset before the vertical one: the order of the two
    # RNG calls determines the result for a given seed.
    left = random.randint(0, reference[0] - size)
    top = random.randint(0, reference[1] - size)
    box = (left, top, left + size, top + size)
    for idx, img in enumerate(data):
        data[idx] = img.crop(box)
    return data
class BatchRandomResolution(object):
    """Resize every image in a batch to one fixed target resolution.

    Despite the name, nothing random happens here: each image is passed
    through ``torchvision.transforms.Resize`` with the configured size.

    Args:
        size: Target resolution. An ``int`` produces a square
            ``size`` x ``size`` output; a two-element iterable is forwarded
            to ``Resize`` unchanged. ``None`` passes validation but is
            rejected by ``Resize`` when the transform is called.
        interpolation: Stored on the instance. NOTE(review): it is not
            forwarded to ``Resize``, which therefore uses its own default
            filter; forwarding it would change results for callers that pass
            a non-default value, so confirm before wiring it through.
    """

    def __init__(self, size=256, interpolation=Image.BILINEAR):
        # collections.abc.Iterable is required: the bare collections.Iterable
        # alias was removed in Python 3.10.
        assert (
            isinstance(size, int)
            or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
            or (size is None)
        )
        self.size = size
        self.interpolation = interpolation

    def __call__(self, imgs):
        """Return a new list with the resized version of each image in ``imgs``."""
        if isinstance(self.size, int):
            target = [self.size, self.size]
        else:
            # A two-element size is already a full target; wrapping it again
            # would hand Resize a nested list it cannot use.
            target = self.size
        # Build the transform once per batch instead of once per image.
        resize = transforms.Resize(target)
        return [resize(img) for img in imgs]
class BatchRandomResolution_test(object):
    """Apply ``transforms.Resize(size)`` to every image in a batch.

    Unlike a square resize, ``size`` is handed to ``Resize`` as-is, so its
    meaning (single int versus two-element sequence) is whatever torchvision's
    ``Resize`` defines for that value.

    Args:
        size: ``int``, two-element iterable, or ``None``. ``None`` (the
            default) passes validation but is rejected by ``Resize`` when the
            transform is called.
        interpolation: Stored on the instance. NOTE(review): it is not
            forwarded to ``Resize``, which uses its own default filter.
    """

    def __init__(self, size=None, interpolation=Image.BILINEAR):
        # collections.abc.Iterable is required: the bare collections.Iterable
        # alias was removed in Python 3.10, which made the default size=None
        # raise AttributeError here.
        assert (
            isinstance(size, int)
            or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
            or (size is None)
        )
        self.size = size
        self.interpolation = interpolation

    def __call__(self, imgs):
        """Return a new list with the resized version of each image in ``imgs``."""
        # Build the transform once per batch instead of once per image.
        resize = transforms.Resize(self.size)
        return [resize(img) for img in imgs]
class BatchTestResolution(object):
    """Resize every image in a batch to a fixed test-time resolution.

    This module previously defined this class twice; the earlier definition
    (which always resized to 144x144) was shadowed by the later one at import
    time and could never be reached. Only the effective definition is kept.

    Args:
        size: Target resolution. An ``int`` produces a square
            ``size`` x ``size`` output; a two-element iterable is forwarded
            to ``Resize`` unchanged. ``None`` (the default) passes validation
            but is rejected by ``Resize`` when the transform is called.
        interpolation: Stored on the instance. NOTE(review): it is not
            forwarded to ``Resize``, which uses its own default filter.
    """

    def __init__(self, size=None, interpolation=Image.BILINEAR):
        # collections.abc.Iterable is required: the bare collections.Iterable
        # alias was removed in Python 3.10.
        assert (
            isinstance(size, int)
            or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
            or (size is None)
        )
        self.size = size
        self.interpolation = interpolation

    def __call__(self, imgs):
        """Return a new list with the resized version of each image in ``imgs``."""
        if isinstance(self.size, int):
            target = [self.size, self.size]
        else:
            # A two-element size is already a full target; wrapping it again
            # would hand Resize a nested list it cannot use.
            target = self.size
        # Build the transform once per batch instead of once per image.
        resize = transforms.Resize(target)
        return [resize(img) for img in imgs]
class BatchToTensor(object):
    """Convert every image in a batch with ``transforms.ToTensor``."""

    def __call__(self, imgs):
        """Return a new list holding one converted tensor per entry of ``imgs``."""
        to_tensor = transforms.ToTensor()
        return [to_tensor(img) for img in imgs]
class BatchRGBToGray(object):
    """Collapse each channel-first image in a batch to a weighted R/G/B sum."""

    @staticmethod
    def _weighted_sum(img):
        """Return ``0.299*R + 0.587*G + 0.114*B`` for one 3-D image."""
        red_term = img[0, :, :] * 0.299
        green_term = img[1, :, :] * 0.587
        # The last term uses the open-ended slice `2:` rather than index `2`,
        # so it keeps a leading axis and the sum broadcasts to (1, H, W) for a
        # three-channel input. NOTE(review): confirm this is intentional and
        # not a typo for `2`; callers see the resulting shape.
        blue_term = img[2:, :, :] * 0.114
        return red_term + green_term + blue_term

    def __call__(self, imgs):
        """Return a new list with the weighted sum computed for each image."""
        return [self._weighted_sum(img) for img in imgs]
class RGBToGray(object):
    """Turn a 4-D RGB tensor into a three-channel grayscale tensor."""

    def __call__(self, img):
        """Return the weighted R/G/B sum repeated on three channels.

        ``img`` is indexed as a 4-D tensor with channels on axis 1
        (presumably N, C, H, W). The result has three identical channels
        stacked on axis 1.
        """
        red = img[:, 0, :, :]
        green = img[:, 1, :, :]
        blue = img[:, 2, :, :]
        luma = red * 0.299 + green * 0.587 + blue * 0.114
        return torch.stack((luma, luma, luma), dim=1)
class BatchRGBToYCbCr(object):
    """Convert each channel-first RGB image in a batch to Y/Cb/Cr planes."""

    @staticmethod
    def _convert(img):
        """Return the three output planes of one 3-D image stacked on axis 0.

        The chroma planes are offset by 128/256, which suggests inputs scaled
        to [0, 1] -- TODO confirm. The coefficients appear to be the usual
        full-range BT.601 (JFIF) ones.
        """
        red = img[0, :, :]
        green = img[1, :, :]
        blue = img[2, :, :]
        luma_offset = 0. / 256.
        chroma_offset = 128. / 256.
        y = luma_offset + red * 0.299000 + green * 0.587000 + blue * 0.114000
        cb = chroma_offset - red * 0.168736 - green * 0.331264 + blue * 0.500000
        cr = chroma_offset + red * 0.500000 - green * 0.418688 - blue * 0.081312
        return torch.stack((y, cb, cr), dim=0)

    def __call__(self, imgs):
        """Return a new list with the converted tensor for each image."""
        return [self._convert(img) for img in imgs]
class YCbCrToRGB(object):
    """Convert a 4-D Y/Cb/Cr tensor back to R/G/B planes."""

    def __call__(self, img):
        """Return the R, G and B planes stacked on axis 1.

        ``img`` is indexed as a 4-D tensor with channels on axis 1
        (presumably N, C, H, W). Both chroma planes are re-centred by
        128/256 before the weights are applied.
        """
        luma = img[:, 0, :, :]
        cb_centered = img[:, 1, :, :] - 128 / 256.
        cr_centered = img[:, 2, :, :] - 128 / 256.
        red = luma + cr_centered * 1.402
        green = luma - cb_centered * 0.344136 - cr_centered * 0.714136
        blue = luma + cb_centered * 1.772
        return torch.stack((red, green, blue), dim=1)
class RGBToYCbCr(object):
    """Convert a 4-D RGB tensor to Y/Cb/Cr planes."""

    def __call__(self, img):
        """Return the Y, Cb and Cr planes stacked on axis 1.

        ``img`` is indexed as a 4-D tensor with channels on axis 1
        (presumably N, C, H, W). The chroma planes are offset by 128/256,
        which suggests inputs scaled to [0, 1] -- TODO confirm.
        """
        red = img[:, 0, :, :]
        green = img[:, 1, :, :]
        blue = img[:, 2, :, :]
        luma_offset = 0. / 256.
        chroma_offset = 128. / 256.
        y = luma_offset + red * 0.299000 + green * 0.587000 + blue * 0.114000
        cb = chroma_offset - red * 0.168736 - green * 0.331264 + blue * 0.500000
        cr = chroma_offset + red * 0.500000 - green * 0.418688 - blue * 0.081312
        return torch.stack((y, cb, cr), dim=1)