@@ -283,6 +283,53 @@ def test_forward_reshape():
283
283
_test_reshape (np .arange (6 ), [- 1 ])
284
284
285
285
286
+ #######################################################################
287
+ # Concatenation
288
+ # -------------
289
+
290
+ def _test_concatenation (data , axis ):
291
+ """ One iteration of concatenation """
292
+
293
+ assert len (data ) >= 1
294
+ need_transpose = False
295
+ if len (data [0 ].shape ) == 1 or len (data [0 ].shape ) == 2 :
296
+ tvm_data = data
297
+ elif len (data [0 ].shape ) == 3 :
298
+ #need_transpose = True
299
+ tvm_data = [np .transpose (d , axes = (0 , 2 , 1 )) for d in data ]
300
+ elif len (data [0 ].shape ) == 4 :
301
+ need_transpose = True
302
+ tvm_data = [np .transpose (d , axes = (0 , 3 , 1 , 2 )) for d in data ]
303
+ else :
304
+ raise NotImplementedError ("Not support input shape {} of reshape : " .
305
+ format (str (len (data ))))
306
+
307
+ with tf .Graph ().as_default ():
308
+ in_data = [
309
+ array_ops .placeholder (shape = tensor .shape , dtype = tensor .dtype , name = "in_{}" .format (idx ))
310
+ for idx , tensor in enumerate (data )]
311
+ out = array_ops .concat (in_data , axis = axis )
312
+ name = ["in_{}:0" .format (idx ) for idx in range (len (data ))]
313
+
314
+ compare_tflite_with_tvm (data , tvm_data , name , in_data , [out ], need_transpose )
315
+
316
+
317
+ def test_forward_concatenation ():
318
+
319
+ _test_concatenation (
320
+ [np .arange (6 ).reshape ((1 , 2 , 1 , 3 )),
321
+ np .arange (6 ).reshape ((1 , 2 , 1 , 3 ))], 1 )
322
+
323
+ _test_concatenation (
324
+ [np .arange (6 ).reshape ((3 , 2 )),
325
+ np .arange (6 ).reshape ((3 , 2 ))], 1 )
326
+
327
+ _test_concatenation (
328
+ [np .arange (6 ).reshape ((2 , 1 , 1 , 3 )),
329
+ np .arange (6 ).reshape ((2 , 1 , 1 , 3 )),
330
+ np .arange (6 ).reshape ((2 , 1 , 1 , 3 ))], 1 )
331
+
332
+
286
333
#######################################################################
287
334
# Squeeze
288
335
# -------
@@ -340,26 +387,51 @@ def test_forward_softmax():
340
387
#######################################################################
341
388
# Mobilenet
342
389
# ---------
390
+
343
391
def test_forward_mobilenet ():
344
392
'''test mobilenet v1 tflite model'''
345
393
# MobilenetV1
346
394
temp = util .tempdir ()
347
395
tflite_model_file = tf_testing .get_workload_official (
348
396
"http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz" ,
349
397
"mobilenet_v1_1.0_224.tflite" , temp )
350
- tflite_model_buf = open (tflite_model_file , "rb" ).read ()
398
+ with open (tflite_model_file , "rb" ) as f :
399
+ tflite_model_buf = f .read ()
351
400
data = np .random .uniform (size = (1 , 224 , 224 , 3 )).astype ('float32' )
352
401
tvm_data = np .transpose (data , axes = (0 , 3 , 1 , 2 ))
353
402
tflite_output = run_tflite_graph (tflite_model_buf , data )
354
403
tvm_output = run_tvm_graph (tflite_model_buf , tvm_data , 'input' )
355
404
tvm .testing .assert_allclose (np .squeeze (tvm_output [0 ]), np .squeeze (tflite_output [0 ]),
356
405
rtol = 1e-5 , atol = 1e-5 )
406
+ temp .remove ()
407
+
408
+ #######################################################################
409
+ # Inception V3
410
+ # ------------
411
+
412
+ def test_forward_inception_v3_net ():
413
+ '''test inception v3 tflite model'''
414
+ # InceptionV3
415
+ temp = util .tempdir ()
416
+ tflite_model_file = tf_testing .get_workload_official (
417
+ "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz" ,
418
+ "inception_v3.tflite" , temp )
419
+ with open (tflite_model_file , "rb" ) as f :
420
+ tflite_model_buf = f .read ()
421
+ data = np .random .uniform (size = (1 , 299 , 299 , 3 )).astype ('float32' )
422
+ tvm_data = np .transpose (data , axes = (0 , 3 , 1 , 2 ))
423
+ tflite_output = run_tflite_graph (tflite_model_buf , data )
424
+ tvm_output = run_tvm_graph (tflite_model_buf , tvm_data , 'input' )
425
+ tvm .testing .assert_allclose (np .squeeze (tvm_output [0 ]), np .squeeze (tflite_output [0 ]),
426
+ rtol = 1e-5 , atol = 1e-5 )
427
+ temp .remove ()
357
428
358
429
#######################################################################
359
430
# Main
360
431
# ----
361
432
if __name__ == '__main__' :
362
433
# Transforms
434
+ test_forward_concatenation ()
363
435
test_forward_reshape ()
364
436
test_forward_squeeze ()
365
437
@@ -370,3 +442,4 @@ def test_forward_mobilenet():
370
442
371
443
# End to End
372
444
test_forward_mobilenet ()
445
+ test_forward_inception_v3_net ()
0 commit comments