Skip to content

Commit e57780c

Browse files
CircleSpinmasahiJocelyn
authored andcommitted
Onnx eyelike (apache#8191)
* add ONNX EyeLike converter * need to implement k * test pass * eyelike tests all pass * Revert "test pass" This reverts commit 0aa7347. * removed comments, black'd, lint * changed == to is in onnx.py Co-authored-by: Masahiro Masuda <masahi129@gmail.com> Co-authored-by: Jocelyn <jocelyn@pop-os.localdomain>
1 parent fd4a1b0 commit e57780c

File tree

2 files changed

+48
-3
lines changed

2 files changed

+48
-3
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,6 +1475,27 @@ def _impl_v11(cls, inputs, attr, params):
14751475
)
14761476

14771477

1478+
class EyeLike(OnnxOpConverter):
1479+
"""Operator converter for EyeLike."""
1480+
1481+
@classmethod
1482+
def _impl_v9(cls, inputs, attr, params):
1483+
in_checked_type = infer_type(inputs[0]).checked_type
1484+
in_dtype = in_checked_type.dtype
1485+
in_shape = list(get_const_tuple(in_checked_type.shape))
1486+
dtype = attr.get("dtype", None)
1487+
if dtype is None:
1488+
dtype = in_dtype
1489+
else:
1490+
dtype = get_type(dtype)
1491+
zeros = _op.zeros(in_shape, dtype)
1492+
dim = in_shape[0]
1493+
indices = _op.arange(_op.const(0), _op.const(dim), dtype="int32")
1494+
ones = _op.full(_op.const(1), (dim,), dtype=dtype)
1495+
k = _op.const(attr.get("k", 0), dtype="int32")
1496+
return _op.scatter_nd(zeros, _op.stack([indices, indices + k], axis=0), ones, "update")
1497+
1498+
14781499
class Greater(OnnxOpConverter):
14791500
"""Operator logical greater."""
14801501

@@ -3158,6 +3179,7 @@ def _get_convert_map(opset):
31583179
"Scatter": Scatter.get_converter(opset),
31593180
"ScatterElements": Scatter.get_converter(opset),
31603181
"ScatterND": ScatterND.get_converter(opset),
3182+
"EyeLike": EyeLike.get_converter(opset),
31613183
"Squeeze": AttrCvt("squeeze", {"axes": "axis"}),
31623184
"Unsqueeze": Unsqueeze.get_converter(opset),
31633185
"Pad": Pad.get_converter(opset),

tests/python/frontend/onnx/test_forward.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4129,6 +4129,7 @@ def verify_softplus(indata):
41294129
verify_softplus(input_data)
41304130

41314131

4132+
@tvm.testing.uses_gpu
41324133
def test_cumsum():
41334134
def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
41344135
cumsum_node = onnx.helper.make_node(
@@ -4205,6 +4206,30 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
42054206
verify_cumsum(data, 1, 1, 1, type="int32")
42064207

42074208

4209+
@tvm.testing.uses_gpu
4210+
def test_eyelike():
4211+
def verify_eyelike(indata):
4212+
node = helper.make_node(
4213+
"EyeLike",
4214+
inputs=["X"],
4215+
outputs=["Y"],
4216+
)
4217+
4218+
graph = helper.make_graph(
4219+
[node],
4220+
"eyelike_test",
4221+
inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(indata.shape))],
4222+
outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(indata.shape))],
4223+
)
4224+
4225+
model = helper.make_model(graph, producer_name="eyelike_test")
4226+
4227+
verify_with_ort_with_inputs(model, [indata], dtype="float32", opset=9)
4228+
4229+
input_data = np.zeros((5, 5), dtype=np.float32)
4230+
verify_eyelike(input_data)
4231+
4232+
42084233
"""
42094234
The following parameterized tests loads the tests that ONNX ships as
42104235
serialized ONNX files, inputs, and outputs. The goal of this test
@@ -4241,9 +4266,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
42414266
"test_cumsum_2d_negative_axis/",
42424267
"test_det_2d/",
42434268
"test_det_nd/",
4244-
"test_eyelike_populate_off_main_diagonal/",
4245-
"test_eyelike_with_dtype/",
4246-
"test_eyelike_without_dtype/",
42474269
"test_matmulinteger/",
42484270
"test_maxpool_2d_same_lower/",
42494271
"test_maxpool_2d_same_upper/",
@@ -4680,4 +4702,5 @@ def repeat(N, D):
46804702
test_wrong_input()
46814703
test_aten()
46824704
test_reverse_sequence()
4705+
test_eyelike()
46834706
test_qlinearconv()

0 commit comments

Comments
 (0)