forked from SamueleBolotta/RIMs-Sequential-MNIST
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
106 lines (74 loc) · 4.04 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gzip
import os
import struct

import cv2
import numpy as np
import torch
def read_idx(filename):
    """Load a gzip-compressed IDX file (the MNIST container format) as an array.

    The header is read as a big-endian 16-bit field that must be zero, a
    one-byte element type code, a one-byte dimension count, and then one
    big-endian unsigned 32-bit size per dimension. Everything after the header
    is the element data.

    Args:
        filename: Path to the ``.gz`` IDX file.

    Returns:
        A writable ``numpy.ndarray`` shaped as declared in the header. Files
        with type code ``0x08`` (unsigned byte, as used by MNIST images and
        labels) yield ``uint8``; multi-byte types keep their big-endian dtype.

    Raises:
        ValueError: If the leading field is not zero, the type code is
            unknown, or the payload size does not match the declared shape.
    """
    # IDX element type codes; multi-byte values are stored big-endian.
    idx_dtypes = {
        0x08: np.dtype(np.uint8),
        0x09: np.dtype(np.int8),
        0x0B: np.dtype('>i2'),
        0x0C: np.dtype('>i4'),
        0x0D: np.dtype('>f4'),
        0x0E: np.dtype('>f8'),
    }
    with gzip.open(filename, 'rb') as f:
        zero, data_type, dims = struct.unpack('>HBB', f.read(4))
        if zero != 0:
            raise ValueError(
                'not an IDX file (bad magic prefix): {}'.format(filename))
        if data_type not in idx_dtypes:
            raise ValueError(
                'unsupported IDX type code 0x{:02X}: {}'.format(data_type, filename))
        shape = struct.unpack('>' + 'I' * dims, f.read(4 * dims))
        payload = f.read()
    # np.frombuffer replaces the deprecated binary mode of np.fromstring. It
    # returns a read-only view over `payload`, so copy to hand back a writable
    # array like the original did.
    return np.frombuffer(payload, dtype=idx_dtypes[data_type]).reshape(shape).copy()
#if __name__ == '__main__':
# a = read_idx('mnist/train-labels-idx1-ubyte.gz')
# print(a.shape)
class MnistData:
    """MNIST digits resized, binarised, flattened and pre-split into batches.

    Training images are resized to ``size``. The validation images are
    prepared three times, at ``size`` enlarged by 10, 5 and 2 pixels in each
    dimension (``val_data1``, ``val_data2`` and ``val_data3`` respectively).
    Every image is thresholded at 120 so pixels are either 0 or 255, then
    flattened, giving batches of shape ``(batch, pixels, 1)`` stored as
    float64.

    Attributes:
        train_data: List of training image batches.
        train_labels: List of training label batches, aligned with
            ``train_data``.
        val_data: The raw validation images exactly as read from disk.
        val_data1, val_data2, val_data3: Lists of validation image batches at
            the +10, +5 and +2 resolutions.
        val_labels: List of validation label batches, shared by all three
            validation resolutions.
    """

    def __init__(self, batch_size, size, k, data_dir='mnist', val_batch_size=512):
        """Read the four MNIST archives and build all batches eagerly.

        Args:
            batch_size: Number of samples per training batch.
            size: Two-element sequence giving the training resolution; it is
                handed to ``cv2.resize`` as the target size.
            k: Not used by this class; kept so existing callers keep working.
            data_dir: Directory holding the four ``*-ubyte.gz`` MNIST files.
            val_batch_size: Number of samples per validation batch.
        """
        train_images = read_idx(os.path.join(data_dir, 'train-images-idx3-ubyte.gz'))
        train_labels = read_idx(os.path.join(data_dir, 'train-labels-idx1-ubyte.gz'))
        # The raw validation images stay available as `val_data`.
        self.val_data = read_idx(os.path.join(data_dir, 't10k-images-idx3-ubyte.gz'))
        val_labels = read_idx(os.path.join(data_dir, 't10k-labels-idx1-ubyte.gz'))

        self.train_data = self._split_batches(
            self._preprocess(train_images, size), batch_size)
        self.val_data1 = self._split_batches(
            self._preprocess(self.val_data, size, pad=10), val_batch_size)
        self.val_data2 = self._split_batches(
            self._preprocess(self.val_data, size, pad=5), val_batch_size)
        self.val_data3 = self._split_batches(
            self._preprocess(self.val_data, size, pad=2), val_batch_size)

        self.train_labels = self._split_batches(train_labels, batch_size)
        self.val_labels = self._split_batches(val_labels, val_batch_size)

    @staticmethod
    def _preprocess(images, size, pad=0):
        """Resize, binarise and flatten every image in ``images``.

        Args:
            images: Array whose first axis indexes the individual images.
            size: Two-element base target size.
            pad: Pixels added to both entries of ``size``.

        Returns:
            A float64 array of shape ``(len(images), pixels, 1)``.
        """
        target = (size[0] + pad, size[1] + pad)
        flat = np.zeros((images.shape[0], target[0] * target[1]))
        for row, img in zip(flat, images):
            resized = cv2.resize(img, target, interpolation=cv2.INTER_NEAREST)
            _, binary = cv2.threshold(resized, 120, 255, cv2.THRESH_BINARY)
            row[:] = np.reshape(binary, (-1))
        # Trailing singleton axis: one feature per pixel position.
        return np.reshape(flat, (flat.shape[0], flat.shape[1], 1))

    @staticmethod
    def _split_batches(array, batch_size):
        """Slice ``array`` along its first axis into consecutive batches.

        The last batch is shorter when the length is not a multiple of
        ``batch_size``.
        """
        return [array[start:start + batch_size]
                for start in range(0, array.shape[0], batch_size)]

    def train_len(self):
        """Return the number of training batches."""
        return len(self.train_labels)

    def val_len(self):
        """Return the number of validation batches."""
        return len(self.val_labels)

    def train_get(self, i):
        """Return the ``(images, labels)`` pair of training batch ``i``."""
        return self.train_data[i], self.train_labels[i]

    def val_get1(self, i):
        """Return validation batch ``i`` at the +10 resolution with its labels."""
        return self.val_data1[i], self.val_labels[i]

    def val_get2(self, i):
        """Return validation batch ``i`` at the +5 resolution with its labels."""
        return self.val_data2[i], self.val_labels[i]

    def val_get3(self, i):
        """Return validation batch ``i`` at the +2 resolution with its labels."""
        return self.val_data3[i], self.val_labels[i]