Skip to content

Commit e3eb5d3

Browse files
committed
mp load images
DEMO_images_load_order_mp_cv2.py
1 parent c4fc6a1 commit e3eb5d3

8 files changed

+579
-73
lines changed

Demo/DEMO_images_load_order_mp_cv2.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import os
2+
import multiprocessing as mp
3+
4+
import cv2
5+
import numpy as np
6+
7+
'''
8+
2018-07-05 Yonv1943 show file images, via multiprocessing
9+
2018-09-04 use multiprocessing for loading images
10+
2018-09-05 add simplify
11+
'''
12+
13+
14+
def img_load(queue, queue_idx__img_paths):
15+
while True:
16+
idx, img_path = queue_idx__img_paths.get()
17+
img = cv2.imread(img_path) # Disk IO
18+
queue.put((img, idx, img_path))
19+
20+
21+
# def img_show_simplify(queue, window_name=''):
22+
# cv2.namedWindow(window_name, cv2.WINDOW_KEEPRATIO)
23+
#
24+
# while True:
25+
# img, idx, img_path = queue.get()
26+
# cv2.imshow(window_name, img)
27+
# cv2.waitKey(1)
28+
29+
30+
def img_show(queue, window_name=''): # check images and keep order
31+
cv2.namedWindow(window_name, cv2.WINDOW_KEEPRATIO)
32+
33+
import bisect
34+
idx_previous = -1
35+
idxs = list()
36+
queue_gets = list()
37+
while True:
38+
queue_get = queue.get()
39+
idx = queue_get[1]
40+
insert = bisect.bisect(idxs, idx) # keep order
41+
idxs.insert(insert, idx)
42+
queue_gets.insert(insert, queue_get)
43+
44+
# print(idx_previous, idxs)
45+
while idxs and idxs[0] == idx_previous + 1:
46+
idx_previous = idxs.pop(0)
47+
img, idx, img_path = queue_gets.pop(0)
48+
if not isinstance(img, np.ndarray): # check images
49+
os.remove(img_path)
50+
print("| Remove no image:", idx, img_path)
51+
elif not (img[-4:, -4:] - 128).any(): # download incomplete
52+
os.remove(img_path)
53+
print("| Remove incomplete image:", idx, img_path)
54+
else:
55+
cv2.imshow(window_name, img)
56+
cv2.waitKey(1)
57+
58+
59+
def run():
60+
src_path = 'F:/url_get_image/ftp.nnvl.noaa.gov_GER_2018'
61+
img_paths = [os.path.join(src_path, f) for f in os.listdir(src_path) if f[-4:] == '.jpg']
62+
print("|Directory perpare to load:", src_path)
63+
print("|Number of images:", len(img_paths), img_paths[0])
64+
65+
mp.set_start_method('spawn')
66+
67+
queue_img = mp.Queue(8)
68+
queue_idx__img_path = mp.Queue(len(img_paths))
69+
[queue_idx__img_path.put(idx__img_path) for idx__img_path in enumerate(img_paths)]
70+
71+
processes = list()
72+
processes.append(mp.Process(target=img_show, args=(queue_img,)), )
73+
processes.extend([mp.Process(target=img_load, args=(queue_img, queue_idx__img_path))
74+
for _ in range(3)])
75+
76+
[setattr(process, "daemon", True) for process in processes]
77+
[process.start() for process in processes]
78+
[process.join() for process in processes]
79+
80+
81+
if __name__ == '__main__':
82+
run()

Demo/DEMO_show_images_mp_cv2.py renamed to Demo/DEMO_images_show_mp_cv2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def run():
3535
print(len(img_paths), img_paths[0])
3636

3737
mp.set_start_method('spawn')
38-
queue_img = mp.Queue(4)
38+
queue_img = mp.Queue(8)
3939

4040
processes = [
4141
mp.Process(target=queue_img_put, args=(queue_img, img_paths)),
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import timeit
2+
3+
loop_times = 2 ** 16
4+
5+
for script in [
6+
"[0 for _ in range(n)]",
7+
"[i for i in range(n)]",
8+
"np.zeros(n, np.int)",
9+
"np.arange(n)",
10+
]:
11+
print(timeit.repeat(stmt=script, setup="import numpy as np;n = 2 ** 9;",
12+
repeat=2, number=loop_times))
13+
14+
"""
15+
[2.0564617830023235, 1.795643180586791]
16+
[1.6117533459143214, 1.5369785699837797]
17+
[0.13760958652508481, 0.11909945438194303]
18+
[0.12012791384374921, 0.12155749387735248]
19+
"""

Demo/Plan/TUTO_mnist_1layers.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
import sys
2+
import time
3+
import shutil
4+
import numpy as np
5+
6+
import os
7+
import cv2
8+
9+
os.environ["PATH"] += ";D:/CUDA/v8.0/bin;"
10+
import tensorflow as tf
11+
12+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '1' # low the warning level
13+
14+
"""
15+
Source: Aymeric Damien: cnn_mnist.py
16+
https://github.com/aymericdamien/TensorFlow-Examples/
17+
Modify: Yonv1943 2018-07-13 13:30:40
18+
19+
2018-07-13 Stable, complete
20+
2018-07-13 Add TensorBoard GRAPHS HISTOGRAM
21+
2018-07-14 Add two dropout layer, lift accuracy to 99.8%>= in test_set
22+
2018-07-14 Remove accuracy from TensorFLow Calculate
23+
2018-07-14 Change to three layers network, not softmax
24+
2018-07-14 Change to one layer network
25+
"""
26+
27+
28+
class Global(object): # Global Variables
29+
batch_size = 500
30+
batch_epoch = 55000 // batch_size # mnist train data is 55000
31+
train_epoch = 2 ** 5 # accuracy in test_set nearly 90%, 15s, (Intel i3-3110M, GTX 720M)
32+
33+
data_dir = 'MNIST_data'
34+
txt_path = 'tf_training_info.txt'
35+
36+
model_save_dir = 'mnist_model'
37+
model_save_name = 'mnist_model'
38+
model_save_path = os.path.join(model_save_dir, model_save_name)
39+
40+
41+
G = Global()
42+
43+
44+
def get_mnist_data(data_dir='MNIST_data'):
45+
from tensorflow.examples.tutorials.mnist import input_data
46+
mnist = input_data.read_data_sets(data_dir, one_hot=True)
47+
48+
train_image = mnist.train.images
49+
train_label = mnist.train.labels
50+
51+
train_image = train_image[:G.batch_epoch * G.batch_size]
52+
train_label = train_label[:G.batch_epoch * G.batch_size]
53+
54+
test_image = mnist.test.images
55+
test_label = mnist.test.labels
56+
57+
data_para = (train_image, train_label, test_image, test_label)
58+
data_para = [np.array(ary, dtype=np.float32) for ary in data_para]
59+
return data_para
60+
61+
62+
def init_session():
63+
image = tf.placeholder(tf.float32, [None, 784], name='Input') # img: 28x28
64+
label = tf.placeholder(tf.float32, [None, 10], name='Label') # 0~9 == 10 classes
65+
66+
w1 = tf.get_variable(shape=[784, 10], name='Weights1')
67+
b1 = tf.get_variable(shape=[10], name='Bias1')
68+
69+
pred = tf.nn.softmax(tf.matmul(image, w1) + b1)
70+
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pred, labels=label)) # high accuracy
71+
optimizer = tf.train.AdamOptimizer().minimize(loss)
72+
73+
sess_para = (image, label, pred, loss, optimizer)
74+
return sess_para
75+
76+
77+
def train_session(sess_para, data_para):
78+
(train_image, train_label, test_image, test_label) = data_para
79+
(image, label, pred, loss, optimizer) = sess_para
80+
81+
shutil.rmtree(G.model_save_dir, ignore_errors=True)
82+
logs = open(G.txt_path, 'a')
83+
sess = tf.Session()
84+
sess.run(tf.global_variables_initializer())
85+
86+
'''train loop init'''
87+
predict, summary, feed_train_label = None, None, None
88+
time0 = time1 = time.time()
89+
print('|Train_epoch: %d |batch: epoch*size" %dx%d' % (G.train_epoch, G.batch_epoch, G.batch_size))
90+
for train_epoch in range(G.train_epoch):
91+
loss_sum = 0.0
92+
for i in range(G.batch_epoch):
93+
j = i * G.batch_size
94+
feed_train_image = train_image[j: j + G.batch_size]
95+
feed_train_label = train_label[j: j + G.batch_size]
96+
97+
feed_dict = {image: feed_train_image, label: feed_train_label}
98+
predict, loss_batch, _ = sess.run([pred, loss, optimizer], feed_dict)
99+
100+
loss_sum += loss_batch
101+
(print(end='='), sys.stdout.flush()) if i % (G.batch_epoch // 16 + 1) == 0 else None
102+
103+
ave_cost = loss_sum / G.batch_epoch
104+
logs.write('%e\n' % ave_cost)
105+
106+
accuracy = np.average(np.equal(np.argmax(predict, 1), np.argmax(feed_train_label, 1)))
107+
108+
time2 = time.time()
109+
print(end="\n|Time: %4.1f|%2d |Loss: %.2e |Inac: %.2e |"
110+
% (time2 - time1, train_epoch, ave_cost, 1 - accuracy))
111+
time1 = time2
112+
print()
113+
print('|Time: %.2f |epoch_batch: %d_%dx%d ' % (time.time() - time0, G.train_epoch, G.batch_epoch, G.batch_size))
114+
115+
'''save'''
116+
os.makedirs(G.model_save_dir)
117+
tf.train.Saver().save(sess, G.model_save_path), print('|model save in:', G.model_save_path)
118+
draw_plot(G.txt_path)
119+
120+
sess.close()
121+
logs.close()
122+
123+
124+
def eval_session(sess_para, data_para):
125+
(train_image, train_label, test_image, test_label) = data_para
126+
(image, label, pred, loss, optimizer) = sess_para
127+
128+
sess = tf.Session()
129+
tf.train.Saver().restore(sess, G.model_save_path)
130+
'''evaluation'''
131+
for print_info, feed_image, feed_label in [
132+
['Train_set', train_image[:len(test_image)], train_label[:len(test_label)]],
133+
['Test_set ', test_image, test_label],
134+
]:
135+
feed_dict = {image: feed_image, label: feed_label}
136+
predicts = pred.eval(feed_dict, session=sess)
137+
accuracy = np.average(np.equal(np.argmax(predicts, 1), np.argmax(feed_label, 1)))
138+
inaccuracy = 1.0 - accuracy
139+
print("|%s |Accuracy: %2.4f%% |Inaccuracy: %.2e" % (print_info, accuracy * 100, inaccuracy))
140+
sess.close()
141+
142+
143+
def real_time_session(sess_para, window_name='cv2_mouse_paint', size=16):
144+
(image, label, pred, loss, optimizer) = sess_para
145+
146+
feed_dict = dict()
147+
feed_dict[image] = np.array([])
148+
149+
sess = tf.Session()
150+
tf.train.Saver().restore(sess, G.model_save_path)
151+
152+
def paint_brush(event, x, y, flags, param): # mouse callback function
153+
global ix, iy, drawing
154+
155+
if event == cv2.EVENT_LBUTTONDOWN:
156+
ix, iy = x, y
157+
drawing = True
158+
elif event == cv2.EVENT_MOUSEMOVE and 'drawing' in globals():
159+
cv2.line(img, (ix, iy), (x, y), 255, size)
160+
ix, iy = x, y
161+
162+
'''hand-writing recognize'''
163+
cv2.rectangle(img, (0, 0), (img.shape[1], 64), 0, -1)
164+
input_image = (np.reshape(cv2.resize(img, (28, 28)), (1, 784)) / 256.0).astype(np.float32)
165+
feed_dict[image] = input_image
166+
predicts = pred.eval(feed_dict, session=sess)
167+
predict = np.argsort(predicts[0])[::-1]
168+
169+
cv2.putText(img, str(predict[0]), (16, 55), cv2.FONT_HERSHEY_SIMPLEX, 2.0, 255, 1, cv2.LINE_AA)
170+
cv2.putText(img, str(predict[1:]), (64, 48), cv2.FONT_HERSHEY_SIMPLEX, 1.0, 255, 1, cv2.LINE_AA)
171+
elif event == cv2.EVENT_LBUTTONUP:
172+
del drawing
173+
elif event == cv2.EVENT_RBUTTONDOWN:
174+
cv2.rectangle(img, (0, 0), (28 * size, 28 * size), 0, -1)
175+
176+
img = np.zeros((28 * size, 28 * size), np.uint8)
177+
cv2.namedWindow(window_name)
178+
cv2.setMouseCallback(window_name, paint_brush)
179+
180+
not_break = True
181+
while not_break:
182+
cv2.imshow(window_name, img)
183+
k = cv2.waitKey(1) & 0xFF
184+
img = np.zeros((28 * size, 28 * size), np.uint8) if k == 8 else img # redraw
185+
not_break = not bool(k == 13 or k == 27) # quit(press Esc or Backspace)
186+
cv2.destroyWindow(window_name)
187+
188+
sess.close()
189+
190+
191+
def draw_plot(ary_path):
192+
import matplotlib.pyplot as plt
193+
194+
ary = np.loadtxt(ary_path)
195+
196+
x_pts = [i for i in range(ary.shape[0])]
197+
y_pts = ary
198+
plt.plot(x_pts, y_pts, linestyle='dashed', marker='x', markersize=3)
199+
plt.show(1.943)
200+
201+
202+
def mouse_paint(window_name='cv2_mouse_paint', size=16):
203+
def paint_brush(event, x, y, flags, param): # mouse callback function
204+
global ix, iy, drawing
205+
206+
if event == cv2.EVENT_LBUTTONDOWN:
207+
ix, iy = x, y
208+
drawing = True
209+
elif event == cv2.EVENT_MOUSEMOVE and 'drawing' in globals():
210+
# 'var_name' in globals; learning from: https://stackoverflow.com/a/1592581/9293137
211+
cv2.line(img, (ix, iy), (x, y), 255, size)
212+
ix, iy = x, y
213+
elif event == cv2.EVENT_LBUTTONUP:
214+
del drawing
215+
216+
img = np.zeros((28 * size, 28 * size), np.uint8)
217+
cv2.namedWindow(window_name)
218+
cv2.setMouseCallback(window_name, paint_brush)
219+
220+
not_break = True
221+
while not_break:
222+
cv2.imshow(window_name, img)
223+
k = cv2.waitKey(1) & 0xFF
224+
img = np.zeros((28 * size, 28 * size), np.uint8) if k == 8 else img # redraw
225+
not_break = not bool(k == 13 or k == 27) # quit(press Esc or Backspace)
226+
cv2.destroyWindow(window_name)
227+
return img
228+
229+
230+
def run():
231+
# data_para = get_mnist_data(G.data_dir)
232+
sess_para = init_session()
233+
234+
# train_session(sess_para, data_para)
235+
# eval_session(sess_para, data_para)
236+
real_time_session(sess_para)
237+
238+
239+
if __name__ == '__main__':
240+
run()

0 commit comments

Comments
 (0)