Skip to content

Commit a96d4e4

Browse files
model_code_tf
1 parent 438274e commit a96d4e4

File tree

1 file changed

+319
-0
lines changed

1 file changed

+319
-0
lines changed

VQA_deep_learning.ipynb

Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [
8+
{
9+
"name": "stderr",
10+
"output_type": "stream",
11+
"text": [
12+
"/usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
13+
" from ._conv import register_converters as _register_converters\n"
14+
]
15+
}
16+
],
17+
"source": [
18+
"import tensorflow as tf \n",
19+
"import numpy as np\n",
20+
"import cv2\n",
21+
"import matplotlib.pyplot as plt\n",
22+
"from datetime import datetime"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": 2,
28+
"metadata": {},
29+
"outputs": [],
30+
"source": [
31+
"# for Tensorboard logging and visualization\n",
32+
"now = datetime.utcnow().strftime(\"%Y%m%d%H%M%S\")\n",
33+
"root_logdir = \"tf_logs\"\n",
34+
"logdir = \"{}/run-{}/\".format(root_logdir, now)"
35+
]
36+
},
37+
{
38+
"cell_type": "code",
39+
"execution_count": 3,
40+
"metadata": {},
41+
"outputs": [],
42+
"source": [
43+
"# a list that specifies convolution-pooling architecture; \n",
44+
"# list index indicate layer position in stack; \n",
45+
"# a pooling layer is represented by a tuple: (pooling type, kernel_size, strides) \n",
46+
"# a convolution layer is represented by a typle: (filter_height, filter_width, depth)\n",
47+
"layers = [(5, 5, 6),\n",
48+
" ('max', (1,2,2,1), (1,2,2,1)),\n",
49+
" (5, 5, 16), \n",
50+
" ('max', (1,2,2,1), (1,2,2,1)),\n",
51+
" (5, 5, 60),\n",
52+
" ('max', (1,2,2,1), (1,2,2,1))] \n",
53+
"\n",
54+
"def conv_pool(x, layers):\n",
55+
" out = x\n",
56+
" n_conv, n_pool = 0, 0\n",
57+
" prev_depth = int(x.shape[3])\n",
58+
" for l in layers:\n",
59+
" if type(l[0]) == int:\n",
60+
" n_conv += 1\n",
61+
" with tf.variable_scope('conv_{}'.format(n_conv), reuse = tf.AUTO_REUSE):\n",
62+
" w = tf.get_variable('filter', initializer=tf.truncated_normal((l[0], l[1], prev_depth, l[2]),0,0.1))\n",
63+
" b = tf.get_variable('bias', initializer=tf.zeros(l[2])) \n",
64+
" out = tf.nn.relu(tf.nn.conv2d(out, w, strides=(1,1,1,1), padding='SAME') + b)\n",
65+
" prev_depth = l[2]\n",
66+
" elif l[0] == 'max':\n",
67+
" n_pool += 1\n",
68+
" out = tf.nn.max_pool(out, l[1], l[2], padding='SAME', name='pool_{}'.format(n_pool))\n",
69+
" elif l[0] == 'avg':\n",
70+
" n_pool += 1\n",
71+
" out = tf.nn.avg_pool(out, l[1], l[2], padding='SAME', name='pool_{}'.format(n_pool))\n",
72+
" return out\n",
73+
"\n",
74+
"# get all frames from video downscaled by a factor\n",
75+
"# return an ndarray of shape (n_frames, height, width, channels)\n",
76+
"def get_frames(path, n_frames, downscale_factor):\n",
77+
" cap = cv2.VideoCapture(path)\n",
78+
" seq = []\n",
79+
" count = 0\n",
80+
" while True:\n",
81+
" success,frame = cap.read()\n",
82+
" if count == n_frames or not success:\n",
83+
" break\n",
84+
" # downscale frame\n",
85+
" width = int(frame.shape[1] / downscale_factor)\n",
86+
" height = int(frame.shape[0] / downscale_factor)\n",
87+
" seq.append(cv2.resize(frame, (width, height), interpolation = cv2.INTER_AREA))\n",
88+
" count += 1\n",
89+
" return np.stack(seq)\n",
90+
"\n",
91+
"# mini-batch generator\n",
92+
"def next_batch(path, labels, n_batches, batch_size, n_frames, downscale_factor):\n",
93+
" for i in range(n_batches):\n",
94+
" x_batch, y_batch = [], []\n",
95+
" for j in range(0, batch_size):\n",
96+
" x_batch.append(get_frames(path.format(i*batch_size+j+1), n_frames, downscale_factor))\n",
97+
" y_batch.append(labels[i*batch_size+j])\n",
98+
" x_batch = np.stack(x_batch)\n",
99+
" yield x_batch, y_batch\n",
100+
" \n",
101+
"# generate feature maps for each video in mini-batch\n",
102+
"# x has shape (batch_size, n_frames, height, width, channels)\n",
103+
"def get_feature_maps(x):\n",
104+
" instances = []\n",
105+
" for i in range(x.shape[0]):\n",
106+
" instances.append(tf.contrib.layers.flatten(conv_pool(x[i, :, :, :, :], layers)))\n",
107+
" return tf.stack(instances, axis=0)\n",
108+
"\n",
109+
"def score_to_label(scores, thresh_1, thresh_2):\n",
110+
" for x in np.nditer(scores, op_flags=['readwrite']):\n",
111+
" if x < thresh_1:\n",
112+
" x[...] = 0\n",
113+
" elif x < thresh_2:\n",
114+
" x[...] = 1\n",
115+
" else:\n",
116+
" x[...] = 2\n",
117+
" return scores"
118+
]
119+
},
120+
{
121+
"cell_type": "code",
122+
"execution_count": 4,
123+
"metadata": {},
124+
"outputs": [
125+
{
126+
"name": "stdout",
127+
"output_type": "stream",
128+
"text": [
129+
"WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
130+
"Instructions for updating:\n",
131+
"Use the retry module or similar alternatives.\n",
132+
"(20, 100, 30600)\n"
133+
]
134+
}
135+
],
136+
"source": [
137+
"path = '/home/mallesh/video-qoe-labeling/dataset/trace_{}.mp4'\n",
138+
"\n",
139+
"height, width, n_channels = 1080, 1920, 3\n",
140+
"downscale_factor = 8\n",
141+
"n_frames = 100\n",
142+
"n_classes = 3\n",
143+
"n_batches, batch_size = 4, 20\n",
144+
"n_hidden = 100 # number of hidden cells in LSTM\n",
145+
"\n",
146+
"X = tf.placeholder(tf.float32, shape=\n",
147+
" (batch_size, n_frames, int(height/downscale_factor), int(width/downscale_factor), n_channels))\n",
148+
"y = tf.placeholder(tf.int32, shape=(batch_size,))\n",
149+
"\n",
150+
"labels = score_to_label(np.loadtxt('/home/mallesh/video-qoe-labeling/dataset/mos.txt'), 2, 3.8)\n",
151+
"\n",
152+
"X_features = get_feature_maps(X)\n",
153+
"print(X_features.shape)\n",
154+
"\n",
155+
"cell = tf.contrib.rnn.BasicLSTMCell(n_hidden)\n",
156+
"output, _ = tf.nn.dynamic_rnn(cell, X_features, initial_state = cell.zero_state(batch_size, dtype=tf.float32))\n",
157+
"\n",
158+
"with tf.variable_scope('out', reuse = tf.AUTO_REUSE):\n",
159+
" w = tf.get_variable('weight', shape=(n_hidden, n_classes))\n",
160+
" b = tf.get_variable('bias', initializer=tf.zeros(n_classes))\n",
161+
" pred = tf.matmul(output[:,-1,:], w) + b\n",
162+
"\n",
163+
"loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=pred, labels=y))\n",
164+
"optimizer = tf.train.AdamOptimizer()\n",
165+
"training_op = optimizer.minimize(loss)\n",
166+
"loss_summary = tf.summary.scalar('loss', loss)\n",
167+
"file_writer = tf.summary.FileWriter(logdir, tf.get_default_graph())"
168+
]
169+
},
170+
{
171+
"cell_type": "code",
172+
"execution_count": 5,
173+
"metadata": {},
174+
"outputs": [
175+
{
176+
"name": "stdout",
177+
"output_type": "stream",
178+
"text": [
179+
"(20, 100, 135, 240, 3)\n",
180+
"[[ 0.37927836 -0.08560888 2.6771262 ]\n",
181+
" [ 0.44744128 0.3947954 2.5882244 ]\n",
182+
" [ 0.04256314 0.39614558 1.9952923 ]\n",
183+
" [ 0.01711836 0.41108784 2.0334034 ]\n",
184+
" [ 0.32527617 0.24322288 3.2082634 ]\n",
185+
" [-0.40870816 0.18336628 2.033424 ]\n",
186+
" [ 0.35842422 0.00692615 1.7578257 ]\n",
187+
" [-0.32743627 0.318483 1.9235626 ]\n",
188+
" [ 0.56437546 0.28331208 3.1881766 ]\n",
189+
" [-0.38265842 -0.08463313 2.4024298 ]\n",
190+
" [ 0.01732781 0.41098124 2.033123 ]\n",
191+
" [ 0.01711985 0.41108695 2.0334013 ]\n",
192+
" [ 0.27588528 0.25923365 1.6460352 ]\n",
193+
" [ 0.05016926 0.39167893 1.9839 ]\n",
194+
" [ 0.05264747 0.39260918 1.9831704 ]\n",
195+
" [ 0.30169457 0.24605171 3.2600179 ]\n",
196+
" [ 0.07404689 0.41454732 2.0233574 ]\n",
197+
" [ 0.32336068 0.2669493 1.6475667 ]\n",
198+
" [ 0.36009893 0.5279878 2.7528164 ]\n",
199+
" [ 0.31793475 0.2614837 1.6379923 ]]\n",
200+
"1.7771614\n",
201+
"(20, 100, 135, 240, 3)\n",
202+
"[[-0.06204156 0.79496956 1.2368748 ]\n",
203+
" [ 0.19840315 1.3091224 0.39372188]\n",
204+
" [-0.0292217 0.8915585 0.75091344]\n",
205+
" [ 0.0919654 0.88935757 0.7646267 ]\n",
206+
" [ 0.1628941 0.7908838 1.2623284 ]\n",
207+
" [ 0.38111767 0.8404923 1.1870992 ]\n",
208+
" [ 0.0832461 0.88951564 0.76364017]\n",
209+
" [-0.06204156 0.79496956 1.2368748 ]\n",
210+
" [ 0.08344238 0.8895122 0.76366234]\n",
211+
" [ 0.07251229 1.0132546 0.71706533]\n",
212+
" [ 0.08342844 0.8895123 0.7636608 ]\n",
213+
" [ 0.08324607 0.88951564 0.76364017]\n",
214+
" [ 0.0832461 0.88951564 0.76364017]\n",
215+
" [-0.06160454 0.7949616 1.2369243 ]\n",
216+
" [ 0.19840315 1.3091224 0.39372188]\n",
217+
" [ 0.11835345 0.88887787 0.7676129 ]\n",
218+
" [ 0.11835345 0.88887787 0.7676129 ]\n",
219+
" [ 0.11844786 0.8888762 0.76762354]\n",
220+
" [ 0.15422454 0.8446137 1.1614242 ]\n",
221+
" [ 0.0832461 0.88951564 0.76364017]]\n",
222+
"1.2838373\n",
223+
"(20, 100, 135, 240, 3)\n",
224+
"[[ 6.6104129e-02 9.5988196e-01 5.6221539e-01]\n",
225+
" [-1.5516879e-01 8.9468837e-01 4.0705174e-01]\n",
226+
" [ 2.9765752e-01 1.0281045e+00 7.2458786e-01]\n",
227+
" [-3.0302963e-01 8.0091304e-01 8.8270134e-01]\n",
228+
" [ 7.1249202e-02 9.6139783e-01 5.6582326e-01]\n",
229+
" [-1.5259342e-01 9.6641046e-01 5.4156667e-01]\n",
230+
" [-3.0302963e-01 8.0091304e-01 8.8270134e-01]\n",
231+
" [ 1.1315355e-01 1.4946108e+00 5.0605452e-01]\n",
232+
" [ 7.1171537e-02 9.6137494e-01 5.6576878e-01]\n",
233+
" [-1.6045438e-01 9.2747623e-01 4.0413094e-01]\n",
234+
" [ 5.6517433e-02 9.8093265e-01 5.8736938e-01]\n",
235+
" [ 7.1892157e-02 9.6383816e-01 5.6365997e-01]\n",
236+
" [ 6.1070051e-02 1.5441496e+00 3.7803373e-01]\n",
237+
" [-1.5515684e-01 8.9469188e-01 4.0706015e-01]\n",
238+
" [-1.6489974e-04 1.0756075e+00 8.7444943e-01]\n",
239+
" [ 7.0583269e-02 9.6120173e-01 5.6535625e-01]\n",
240+
" [ 7.1248129e-02 9.6139753e-01 5.6582248e-01]\n",
241+
" [-1.5516931e-01 8.9468843e-01 4.0705168e-01]\n",
242+
" [-1.5516931e-01 8.9468843e-01 4.0705168e-01]\n",
243+
" [ 9.0059519e-02 1.5057209e+00 3.3565113e-01]]\n",
244+
"1.1355282\n",
245+
"(20, 100, 135, 240, 3)\n",
246+
"[[-0.05554762 1.2411804 0.7248899 ]\n",
247+
" [-0.05554762 1.2411804 0.7248899 ]\n",
248+
" [ 0.0207992 1.1436464 1.2259097 ]\n",
249+
" [-0.05554762 1.2411804 0.7248899 ]\n",
250+
" [ 0.09306119 1.028489 0.9170438 ]\n",
251+
" [ 0.0207992 1.1436464 1.2259097 ]\n",
252+
" [ 0.0207992 1.1436464 1.2259097 ]\n",
253+
" [-0.05554762 1.2411804 0.7248899 ]\n",
254+
" [ 0.19181803 0.91051793 0.36328435]\n",
255+
" [-0.05554762 1.2411804 0.7248899 ]\n",
256+
" [-0.05554762 1.2411804 0.7248899 ]\n",
257+
" [-0.05554762 1.2411804 0.7248899 ]\n",
258+
" [-0.05554762 1.2411804 0.7248898 ]\n",
259+
" [ 0.09306119 1.028489 0.9170438 ]\n",
260+
" [ 0.09306119 1.028489 0.9170438 ]\n",
261+
" [ 0.34042683 0.6978265 0.5554383 ]\n",
262+
" [-0.05554762 1.2411804 0.7248899 ]\n",
263+
" [-0.05554762 1.2411804 0.7248899 ]\n",
264+
" [-0.05554762 1.2411804 0.7248899 ]\n",
265+
" [-0.05554762 1.2411804 0.7248899 ]]\n",
266+
"1.1465046\n"
267+
]
268+
}
269+
],
270+
"source": [
271+
"saver = tf.train.Saver()\n",
272+
"with tf.Session() as sess:\n",
273+
" sess.run(tf.global_variables_initializer())\n",
274+
" batch_num = 0\n",
275+
" for X_batch, y_batch in next_batch(path, labels, n_batches, batch_size, n_frames, downscale_factor): \n",
276+
" print(X_batch.shape)\n",
277+
" batch_num += 1\n",
278+
" summary_str = loss_summary.eval(feed_dict={X: X_batch, y: y_batch})\n",
279+
" file_writer.add_summary(summary_str, batch_num)\n",
280+
" sess.run(training_op, feed_dict={X: X_batch, y: y_batch})\n",
281+
" saver.save(sess, '/tmp/after_batch_{}.ckpt'.format(batch_num))\n",
282+
" print(pred.eval(feed_dict={X: X_batch, y: y_batch}))\n",
283+
" print(loss.eval(feed_dict={X: X_batch, y: y_batch}))\n",
284+
" \n",
285+
" saver.save(sess, '/tmp/final.ckpt')\n",
286+
"\n",
287+
"file_writer.close()"
288+
]
289+
},
290+
{
291+
"cell_type": "code",
292+
"execution_count": null,
293+
"metadata": {},
294+
"outputs": [],
295+
"source": []
296+
}
297+
],
298+
"metadata": {
299+
"kernelspec": {
300+
"display_name": "Python 3",
301+
"language": "python",
302+
"name": "python3"
303+
},
304+
"language_info": {
305+
"codemirror_mode": {
306+
"name": "ipython",
307+
"version": 3
308+
},
309+
"file_extension": ".py",
310+
"mimetype": "text/x-python",
311+
"name": "python",
312+
"nbconvert_exporter": "python",
313+
"pygments_lexer": "ipython3",
314+
"version": "3.5.2"
315+
}
316+
},
317+
"nbformat": 4,
318+
"nbformat_minor": 2
319+
}

0 commit comments

Comments
 (0)