-
Notifications
You must be signed in to change notification settings - Fork 394
/
Batch.py
63 lines (55 loc) · 2.01 KB
/
Batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import numpy as np
class BatchGenerator(object):
    """Serve mini-batches of (X, y) pairs, optionally reshuffling each epoch.

    The inputs may be ndarrays or anything ``np.asarray`` accepts (lists,
    tuples, ...). Both are indexed along their first axis, which therefore
    must have the same length for X and y.

    Example::

        data_train = BatchGenerator(X=X_train_all, y=y_train_all, shuffle=True)
        data_test = BatchGenerator(X=X_test_all, y=y_test_all, shuffle=False)
        X = data_train.X   # whole feature array, in current (shuffled) order
        y = data_train.y
        # or, batch by batch:
        X_batch, y_batch = data_train.next_batch(batch_size)

    Args:
        X: Features; the first axis indexes examples.
        y: Targets; the first axis indexes examples, aligned with X.
        shuffle: If True, permute the examples once at construction and
            again every time a new epoch starts.
        rng: Source of randomness for shuffling. ``None`` (default) uses
            NumPy's global random state; otherwise any value accepted by
            ``np.random.default_rng`` (e.g. an int seed or a ``Generator``).

    Raises:
        ValueError: If X and y differ in length along the first axis.
    """

    def __init__(self, X, y, shuffle=False, rng=None):
        # np.asarray returns ndarray inputs unchanged, so no type check needed.
        self._X = np.asarray(X)
        self._y = np.asarray(y)
        self._epochs_completed = 0
        self._index_in_epoch = 0
        self._number_examples = self._X.shape[0]
        # Shuffling indexes X and y with one shared permutation, so their
        # lengths must agree or rows would be misaligned / out of bounds.
        if self._y.ndim == 0 or self._y.shape[0] != self._number_examples:
            raise ValueError(
                "X and y must have the same number of examples along the "
                "first axis; got X.shape={} and y.shape={}".format(
                    self._X.shape, self._y.shape))
        self._shuffle = shuffle
        # None keeps the historical behaviour of drawing from the global
        # NumPy random state (np.random.permutation); anything else becomes
        # a dedicated Generator for reproducible shuffling.
        self._rng = np.random if rng is None else np.random.default_rng(rng)
        if self._shuffle:
            self._shuffle_data()

    def _shuffle_data(self):
        """Apply one random permutation to both X and y, keeping rows aligned."""
        perm = self._rng.permutation(self._number_examples)
        self._X = self._X[perm]
        self._y = self._y[perm]

    @property
    def X(self):
        """Feature array in its current order (same object as ``x``)."""
        return self._X

    @property
    def x(self):
        """Feature array in its current order."""
        return self._X

    @property
    def y(self):
        """Target array in its current order."""
        return self._y

    @property
    def num_examples(self):
        """Number of examples (length of the first axis of X)."""
        return self._number_examples

    @property
    def epochs_completed(self):
        """How many times ``next_batch`` has wrapped around to a new epoch."""
        return self._epochs_completed

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` examples as ``(X_batch, y_batch)``.

        Batches are consecutive slices of the (possibly shuffled) data. When
        the requested batch would run past the end of the data, the epoch
        counter is incremented, the data is reshuffled (if ``shuffle`` is
        enabled), and the batch is taken from the start instead.

        NOTE: examples left at the end of an epoch (fewer than
        ``batch_size``) are skipped, not returned. With ``shuffle=False``
        and a fixed ``batch_size``, the last ``num_examples % batch_size``
        examples are therefore never served.

        Args:
            batch_size: Number of examples to return, ``0 <= batch_size <=
                num_examples``.

        Returns:
            Tuple ``(X_batch, y_batch)`` of slices along the first axis.

        Raises:
            ValueError: If ``batch_size`` is negative or exceeds
                ``num_examples``.
        """
        # Validate before touching any state. (Previously an ``assert`` did
        # this after the epoch counter/shuffle had already been updated, and
        # asserts disappear entirely under ``python -O``.)
        if batch_size < 0 or batch_size > self._number_examples:
            raise ValueError(
                "batch_size must be between 0 and {} (the number of "
                "examples), got {}".format(self._number_examples, batch_size))
        start = self._index_in_epoch
        end = start + batch_size
        if end > self._number_examples:
            # Not enough examples left in this epoch: begin a new one.
            self._epochs_completed += 1
            if self._shuffle:
                self._shuffle_data()
            start, end = 0, batch_size
        self._index_in_epoch = end
        return self._X[start:end], self._y[start:end]