Skip to content

Commit 9db61fe

Browse files
siju-samueltrevor-m
authored andcommitted
[ONNX]GatherNd, Round, IsNaN, IsInf (apache#5445)
1 parent b4d5956 commit 9db61fe

File tree

2 files changed

+80
-0
lines changed

2 files changed

+80
-0
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,14 @@ def _impl_v1(cls, inputs, attr, params):
942942
extras={'axis': axis})(inputs, {})
943943

944944

945+
class GatherND(OnnxOpConverter):
946+
""" Operator converter for GatherND.
947+
"""
948+
@classmethod
949+
def _impl_v1(cls, inputs, attr, params):
950+
return _op.gather_nd(inputs[0], inputs[1])
951+
952+
945953
class Greater(OnnxOpConverter):
946954
""" Operator logical greater.
947955
"""
@@ -1536,6 +1544,9 @@ def _get_convert_map(opset):
15361544
'Reciprocal': Reciprocal.get_converter(opset),
15371545
'Floor': Renamer('floor'),
15381546
'Ceil': Renamer('ceil'),
1547+
'Round': Renamer('round'),
1548+
'IsInf': Renamer('isinf'),
1549+
'IsNaN': Renamer('isnan'),
15391550
'Sqrt': Renamer('sqrt'),
15401551
'Relu': Renamer('relu'),
15411552
'LeakyRelu': Renamer('leaky_relu'),
@@ -1606,6 +1617,7 @@ def _get_convert_map(opset):
16061617
'DepthToSpace': DepthToSpace.get_converter(opset),
16071618
'SpaceToDepth': SpaceToDepth.get_converter(opset),
16081619
'Gather': Gather.get_converter(opset),
1620+
'GatherND': GatherND.get_converter(opset),
16091621
'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
16101622
'Unsqueeze': Unsqueeze.get_converter(opset),
16111623
'Pad': Pad.get_converter(opset),

tests/python/frontend/onnx/test_forward.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,70 @@ def test_clip():
542542
{'min': -1.0, 'max': 1.0})
543543

544544

545+
546+
def test_round():
547+
_test_onnx_op_elementwise((2, 4, 5, 6), np.round, {}, 'float32', 'Round', {})
548+
549+
550+
def _test_finite_ops(inshape, outfunc, npargs, dtype, opname, kwargs):
551+
indata = np.random.choice(a=[np.nan, np.inf, -np.inf, 0.5, 1.0, 0], size=inshape).astype(dtype)
552+
553+
outdata = outfunc(indata, **npargs)
554+
y = helper.make_node(opname, ['in'], ['out'], **kwargs)
555+
556+
graph = helper.make_graph([y],
557+
opname+'_test',
558+
inputs=[helper.make_tensor_value_info("in",
559+
TensorProto.FLOAT, list(indata.shape))],
560+
outputs=[helper.make_tensor_value_info("out",
561+
TensorProto.BOOL, list(outdata.shape))])
562+
563+
model = helper.make_model(graph, producer_name=opname+'_test')
564+
565+
for target, ctx in ctx_list():
566+
tvm_out = get_tvm_output(
567+
model, indata, target, ctx, outdata.shape, dtype)
568+
569+
tvm.testing.assert_allclose(outdata, tvm_out)
570+
571+
572+
def test_isinf():
573+
_test_finite_ops((2, 4, 5, 6), np.isinf, {}, 'float32', 'IsInf', {})
574+
575+
576+
def test_isnan():
577+
_test_finite_ops((2, 4, 5, 6), np.isnan, {}, 'float32', 'IsNaN', {})
578+
579+
580+
def verify_gather_nd(in_shape, indices, dtype):
581+
x = np.random.uniform(size=in_shape).astype(dtype)
582+
indices = np.array(indices, dtype="int32")
583+
out_np = topi.testing.gather_nd_python(x, indices)
584+
585+
y = helper.make_node("GatherND", ['in', 'indices'], ['out'])
586+
587+
graph = helper.make_graph([y],
588+
'gather_test',
589+
inputs=[helper.make_tensor_value_info("in",
590+
TensorProto.FLOAT, list(in_shape)),
591+
helper.make_tensor_value_info("indices",
592+
TensorProto.INT32, list(indices.shape))],
593+
outputs=[helper.make_tensor_value_info("out",
594+
TensorProto.FLOAT, list(out_np.shape))])
595+
model = helper.make_model(graph, producer_name='gather_test')
596+
597+
for target, ctx in ctx_list():
598+
tvm_out = get_tvm_output(
599+
model, [x, indices], target, ctx, out_np.shape)
600+
tvm.testing.assert_allclose(out_np, tvm_out)
601+
602+
603+
def test_gather_nd():
604+
verify_gather_nd((2, 2), [[0,0],[1,1]], 'int32')
605+
verify_gather_nd((3, 3, 3), [[0,1],[1,0]] , 'float32')
606+
verify_gather_nd((4, 3, 5, 6), [[2, 1, 0, 0]], 'float32')
607+
608+
545609
def test_onehot():
546610
indices_shape = [10]
547611
indices_array = np.random.randint(
@@ -2379,11 +2443,15 @@ def verify_topk(input_dims, K, axis=-1):
23792443
test_slice()
23802444
test_floor()
23812445
test_ceil()
2446+
test_round()
2447+
test_isinf()
2448+
test_isnan()
23822449
test_clip()
23832450
test_onehot()
23842451
test_matmul()
23852452
test_batch_matmul()
23862453
test_gather()
2454+
test_gather_nd()
23872455
test_lrn()
23882456
test_instance_norm()
23892457
test_upsample()

0 commit comments

Comments
 (0)