@@ -1299,6 +1299,32 @@ def test_erf():
12991299 z = scipy .special .erf (x )
13001300 verify_erf (x , z )
13011301
1302+ def verify_where (condition , x , y , dtype , outdata ):
1303+ node = helper .make_node ('Where' , inputs = ['condition' , 'x' , 'y' ], outputs = ['out' ])
1304+ graph = helper .make_graph ([node ],
1305+ 'where_test' ,
1306+ inputs = [helper .make_tensor_value_info ('condition' , TensorProto .BOOL , list (condition .shape )),
1307+ helper .make_tensor_value_info ('x' , dtype , list (x .shape )),
1308+ helper .make_tensor_value_info ('y' , dtype , list (y .shape ))],
1309+ outputs = [helper .make_tensor_value_info ('out' , dtype , list (outdata .shape ))])
1310+ model = helper .make_model (graph , producer_name = 'where_test' )
1311+
1312+ for target , ctx in ctx_list ():
1313+ tvm_out = get_tvm_output (model , [condition , x , y ], target , ctx , outdata .shape )
1314+ tvm .testing .assert_allclose (outdata , tvm_out )
1315+
1316+ def test_where ():
1317+ condition = np .array ([[1 , 0 ], [1 , 1 ]], dtype = np .bool )
1318+ x = np .array ([[1 , 2 ], [3 , 4 ]], dtype = np .int64 )
1319+ y = np .array ([[9 , 8 ], [7 , 6 ]], dtype = np .int64 )
1320+ outdata = np .where (condition , x , y )
1321+ verify_where (condition , x , y , TensorProto .INT64 , outdata )
1322+
1323+ x = np .array ([[1 , 2 ], [3 , 4 ]], dtype = np .float32 )
1324+ y = np .array ([[9 , 8 ], [7 , 6 ]], dtype = np .float32 )
1325+ outdata = np .where (condition , x , y )
1326+ verify_where (condition , x , y , TensorProto .FLOAT , outdata )
1327+
13021328
13031329if __name__ == '__main__' :
13041330 test_flatten ()
@@ -1347,3 +1373,4 @@ def test_erf():
13471373 test_and ()
13481374 test_tile ()
13491375 test_erf ()
1376+ test_where ()
0 commit comments