Skip to content

Commit 07606e4

Browse files
soiferjjroesch
authored andcommitted
[Relay][Frontend][ONNX] Add support for op Where (#4184)
* Add support for op Where * Update impl version
1 parent 9cc7874 commit 07606e4

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,13 @@ class Erf(OnnxOpConverter):
922922
def _impl_v1(cls, inputs, attr, params):
923923
return _op.erf(inputs[0])
924924

925+
class Where(OnnxOpConverter):
926+
"""Operator converter for Where
927+
"""
928+
@classmethod
929+
def _impl_v9(cls, inputs, attr, params):
930+
return _op.where(inputs[0], inputs[1], inputs[2])
931+
925932

926933
# compatible operators that do NOT require any conversion.
927934
_identity_list = []
@@ -1042,7 +1049,8 @@ def _get_convert_map(opset):
10421049
'Not': Not.get_converter(opset),
10431050
'And': And.get_converter(opset),
10441051
'Tile': Tile.get_converter(opset),
1045-
'Erf': Erf.get_converter(opset)
1052+
'Erf': Erf.get_converter(opset),
1053+
'Where': Where.get_converter(opset)
10461054
}
10471055

10481056

tests/python/frontend/onnx/test_forward.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,6 +1299,32 @@ def test_erf():
12991299
z = scipy.special.erf(x)
13001300
verify_erf(x, z)
13011301

1302+
def verify_where(condition, x, y, dtype, outdata):
1303+
node = helper.make_node('Where', inputs=['condition', 'x', 'y'], outputs=['out'])
1304+
graph = helper.make_graph([node],
1305+
'where_test',
1306+
inputs=[helper.make_tensor_value_info('condition', TensorProto.BOOL, list(condition.shape)),
1307+
helper.make_tensor_value_info('x', dtype, list(x.shape)),
1308+
helper.make_tensor_value_info('y', dtype, list(y.shape))],
1309+
outputs=[helper.make_tensor_value_info('out', dtype, list(outdata.shape))])
1310+
model = helper.make_model(graph, producer_name='where_test')
1311+
1312+
for target, ctx in ctx_list():
1313+
tvm_out = get_tvm_output(model, [condition, x, y], target, ctx, outdata.shape)
1314+
tvm.testing.assert_allclose(outdata, tvm_out)
1315+
1316+
def test_where():
1317+
condition = np.array([[1, 0], [1, 1]], dtype=np.bool)
1318+
x = np.array([[1, 2], [3, 4]], dtype=np.int64)
1319+
y = np.array([[9, 8], [7, 6]], dtype=np.int64)
1320+
outdata = np.where(condition, x, y)
1321+
verify_where(condition, x, y, TensorProto.INT64, outdata)
1322+
1323+
x = np.array([[1, 2], [3, 4]], dtype=np.float32)
1324+
y = np.array([[9, 8], [7, 6]], dtype=np.float32)
1325+
outdata = np.where(condition, x, y)
1326+
verify_where(condition, x, y, TensorProto.FLOAT, outdata)
1327+
13021328

13031329
if __name__ == '__main__':
13041330
test_flatten()
@@ -1347,3 +1373,4 @@ def test_erf():
13471373
test_and()
13481374
test_tile()
13491375
test_erf()
1376+
test_where()

0 commit comments

Comments
 (0)