-
Notifications
You must be signed in to change notification settings - Fork 0
/
loadData.py
63 lines (47 loc) · 1.81 KB
/
loadData.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import tensorflow as tf
import numpy as np
import os
# Module-level image width and height. The functions in this file take their
# own img_width/img_height arguments, so these are presumably meant to be
# passed in by callers -- TODO confirm against the training script.
img_width, img_height = 224, 224
def get_file(file_dir):
    """Collect file paths under ``file_dir`` and derive a binary label for each.

    Every file found while walking ``file_dir`` is treated as an image. Its
    label comes from the name of the folder that directly contains it: ``0``
    when that folder is named ``cat``, ``1`` for any other folder name.

    Args:
        file_dir: Root directory to walk recursively.

    Returns:
        A tuple ``(image_list, label_list)``: a list of file paths (``str``)
        and a list of integer labels of the same length. Both lists are
        shuffled with the same permutation drawn from NumPy's global random
        state, so ``image_list[i]`` always corresponds to ``label_list[i]``.
        Both lists are empty when no files are found.
    """
    images = []
    labels = []
    for root, _sub_folders, files in os.walk(file_dir):
        # Take the containing folder's name with os.path so the platform's
        # own separator is used (a hard-coded '\\' split only works on
        # Windows); normpath drops any trailing separator first.
        folder_name = os.path.basename(os.path.normpath(root))
        label = 0 if folder_name == 'cat' else 1
        # Label each file as it is discovered so paths and labels can never
        # drift out of alignment, whatever the directory layout is.
        for name in files:
            images.append(os.path.join(root, name))
            labels.append(label)

    # One shared permutation keeps every (path, label) pair together.
    order = np.random.permutation(len(images))
    image_list = [images[i] for i in order]
    label_list = [labels[i] for i in order]
    return image_list, label_list
def get_batch(image_list, label_list, img_width, img_height, batch_size, capacity, num_threads=64):
    """Build TensorFlow 1.x queue-based ops that produce image/label batches.

    Each element is read from disk, decoded as a 3-channel JPEG, centrally
    cropped or zero-padded to ``img_height`` x ``img_width``, and standardized
    per image before being batched.

    Args:
        image_list: Sequence of image file paths.
        label_list: Sequence of integer labels, parallel to ``image_list``.
        img_width: Target image width.
        img_height: Target image height.
        batch_size: Number of examples per batch.
        capacity: Maximum number of elements held in the batching queue.
        num_threads: Number of threads enqueuing examples (default 64, the
            value that was previously hard-coded).

    Returns:
        A tuple ``(image_batch, label_batch)`` of tensors; ``label_batch`` is
        reshaped to ``[batch_size]``.

    NOTE(review): these are queue-runner based ops, so the caller presumably
    has to start queue runners in its session -- confirm against the caller.
    """
    image = tf.cast(image_list, tf.string)
    label = tf.cast(label_list, tf.int32)

    # Produce one (path, label) pair at a time from the two parallel lists.
    input_queue = tf.train.slice_input_producer([image, label])
    label = input_queue[1]
    image_contents = tf.read_file(input_queue[0])
    image = tf.image.decode_jpeg(image_contents, channels=3)

    # The API takes (target_height, target_width) in that order; keywords are
    # used so width and height cannot be swapped for non-square targets.
    image = tf.image.resize_image_with_crop_or_pad(
        image, target_height=img_height, target_width=img_width)
    # Standardize the image.
    image = tf.image.per_image_standardization(image)

    image_batch, label_batch = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=capacity)
    label_batch = tf.reshape(label_batch, [batch_size])
    return image_batch, label_batch
def onehot(labels, n_class=None):
    """Convert integer class labels into a one-hot encoded matrix.

    Args:
        labels: Sequence of integer class indices.
        n_class: Number of columns in the result. When ``None`` (default) it
            is inferred as ``max(labels) + 1``, or ``0`` for empty input.
            Pass it explicitly when a batch may not contain the highest
            class, otherwise the matrix width varies from batch to batch.

    Returns:
        A float NumPy array of shape ``(len(labels), n_class)`` with a ``1``
        in column ``labels[i]`` of row ``i`` and ``0`` elsewhere.

    Raises:
        IndexError: If a label is out of range for ``n_class`` (raised by
            NumPy's indexing).
    """
    n_sample = len(labels)
    if n_class is None:
        # max() raises on an empty sequence, so fall back to zero classes.
        n_class = max(labels) + 1 if n_sample else 0
    onehot_labels = np.zeros((n_sample, n_class))
    if n_sample:
        # Fancy indexing sets exactly one entry per row.
        onehot_labels[np.arange(n_sample), labels] = 1
    return onehot_labels