@@ -108,7 +108,6 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
108
108
in_node = [0 ]* len (in_name )
109
109
for i in range (len (in_name )):
110
110
in_node [i ] = in_name [i ].split (':' )[0 ] if ":" in in_name [i ] else in_name [i ]
111
-
112
111
with tf .Session () as sess :
113
112
if init_global_variables :
114
113
sess .run (variables .global_variables_initializer ())
@@ -483,7 +482,7 @@ def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
483
482
in_data = tf .placeholder (dtype , ip_shape , name = "in_data" )
484
483
indices = tf .placeholder ("int32" , indice_shape , name = "indices" )
485
484
tf .gather (in_data , indices , axis = axis )
486
- np_data = np .random .uniform (size = ip_shape ).astype (dtype )
485
+ np_data = np .random .uniform (1 , 10 , size = ip_shape ).astype (dtype )
487
486
488
487
def _fill_indices (indice_value ):
489
488
indices = np .array (ip_shape , dtype = dtype )
@@ -500,14 +499,14 @@ def test_forward_gather():
500
499
'''test GatherV2 layer'''
501
500
_test_gather ((4 ,), (1 ,), 1 , 0 , 'int32' )
502
501
_test_gather ((4 ,), (1 ,), 1 , 0 , 'float32' )
503
- _test_gather ((1 ,4 ), (1 ,), [0 ], 0 , 'int32' )
504
- _test_gather ((4 ,), (1 ,2 , 2 ), [[[1 ,0 ],[0 ,1 ]]], 0 , 'float32' )
505
- _test_gather ((2 ,2 ), (1 ,2 , 2 ), [[[1 ,0 ],[0 ,1 ]]], 0 , 'int32' )
506
- _test_gather ((2 ,2 ), (1 ,2 , 2 ), [[[1 ,0 ],[0 ,1 ]]], 1 , 'int32' )
507
- _test_gather ((2 ,2 ), (1 ,2 , 2 ), [[[1 ,0 ],[0 ,1 ]]], 0 , 'float32' )
508
- _test_gather ((3 ,3 , 3 ), (1 ,1 , 2 ), [[[1 ,0 ]]], 0 , 'int32' )
509
- _test_gather ((3 ,3 , 3 ), (1 ,1 , 2 ), [[[1 ,0 ]]], 2 , 'int32' )
510
- _test_gather ((4 ,3 , 5 , 6 ), (1 ,4 ), [[2 ,1 , 0 , 0 ]], 0 , 'float32' )
502
+ _test_gather ((1 , 4 ), (1 ,), [0 ], 0 , 'int32' )
503
+ _test_gather ((4 ,), (1 , 2 , 2 ), [[[1 , 0 ],[0 , 1 ]]], 0 , 'float32' )
504
+ _test_gather ((2 , 2 ), (1 , 2 , 2 ), [[[1 , 0 ],[0 , 1 ]]], 0 , 'int32' )
505
+ _test_gather ((2 , 2 ), (1 , 2 , 2 ), [[[1 , 0 ],[0 , 1 ]]], 1 , 'int32' )
506
+ _test_gather ((2 , 2 ), (1 , 2 , 2 ), [[[1 , 0 ],[0 , 1 ]]], 0 , 'float32' )
507
+ _test_gather ((3 , 3 , 3 ), (1 , 1 , 2 ), [[[1 , 0 ]]], 0 , 'int32' )
508
+ _test_gather ((3 , 3 , 3 ), (1 , 1 , 2 ), [[[1 , 0 ]]], 2 , 'int32' )
509
+ _test_gather ((4 , 3 , 5 , 6 ), (1 , 4 ), [[2 , 1 , 0 , 0 ]], 0 , 'float32' )
511
510
512
511
513
512
def _test_gather_v1 (ip_shape , indice_shape , indice_value , dtype ):
@@ -620,10 +619,10 @@ def _test_unstack(ip_shape, axis, dtype):
620
619
def test_forward_unstack ():
621
620
'''test unstack layer'''
622
621
_test_unstack ((6 ,), 0 , 'int32' )
623
- _test_unstack ((2 ,6 ), 1 , 'float64' )
622
+ _test_unstack ((2 , 6 ), 1 , 'float64' )
624
623
# negative axis
625
- _test_unstack ((1 ,4 ), - 1 , 'int32' )
626
- _test_unstack ((3 ,6 , 4 ), - 2 , 'float32' )
624
+ _test_unstack ((1 , 4 ), - 1 , 'int32' )
625
+ _test_unstack ((3 , 6 , 4 ), - 2 , 'float32' )
627
626
628
627
629
628
#######################################################################
@@ -863,6 +862,22 @@ def test_forward_logical():
863
862
test_logical_not ()
864
863
865
864
865
+ #######################################################################
866
+ # Where, Select
867
+ # -------------
868
+ def test_where ():
869
+ ''' Where: return elements depending on conditions'''
870
+ with tf .Graph ().as_default ():
871
+ with tf .Session () as sess :
872
+ input1 = tf .placeholder (tf .int32 , shape = [1 , 4 , 4 , 3 ], name = 'input1' )
873
+ input2 = tf .placeholder (tf .int32 , shape = [1 , 4 , 4 , 3 ], name = 'input2' )
874
+ mask = input1 > input2
875
+ tf .where (mask , input1 + 1 , input2 * 2 )
876
+ in_data1 = np .random .uniform (0 , 10 , size = (1 , 4 , 4 , 3 )).astype ("uint32" )
877
+ in_data2 = np .random .uniform (0 , 10 , size = (1 , 4 , 4 , 3 )).astype ("uint32" )
878
+ compare_tf_with_tvm ([in_data1 , in_data2 ], ['input1:0' , 'input2:0' ], 'Select:0' )
879
+
880
+
866
881
#######################################################################
867
882
# Inception V3
868
883
# ------------
@@ -1299,3 +1314,4 @@ def test_forward_rel_ops():
1299
1314
# Relational ops
1300
1315
test_forward_rel_ops ()
1301
1316
test_forward_logical ()
1317
+ test_where ()
0 commit comments