Skip to content

Commit 5bc54eb

Browse files
zxybazhJosh Fromm
authored andcommitted
Fix initializer for CumSum. (apache#9)
1 parent b8f2a3a commit 5bc54eb

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

python/tvm/relax/frontend/onnx_frontend.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -454,8 +454,7 @@ class CumSum(OnnxOpConverter):
454454
def _impl_v13(cls, bb, inputs, attr):
455455
assert getattr(attr, "reverse", 0) == 0, "reverse is not supported yet"
456456
if len(inputs) > 1:
457-
# axis = int(infer_value(inputs[1], params).numpy())
458-
axis = inputs[1]
457+
axis = int(inputs[1].data.numpy())
459458
else:
460459
axis = None
461460
return bb.emit_te(

tests/python/relax/frontend/test_onnx_frontend.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737

3838

3939
def generate_random_inputs(
40-
model: ModelProto, inputs: Dict[str, np.array] = None
40+
model: ModelProto, inputs: Optional[Dict[str, np.array]] = None
4141
) -> Dict[str, np.array]:
4242
input_values = {}
4343
# Iterate through model inputs and extract their shape.
@@ -559,13 +559,13 @@ def test_cumsum():
559559
"cumsum_test",
560560
inputs=[
561561
helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
562-
helper.make_tensor_value_info("axis", TensorProto.INT64, ()),
563562
],
563+
initializer=[helper.make_tensor("axis", TensorProto.INT64, (), [1])],
564564
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)],
565565
)
566566

567567
model = helper.make_model(graph, producer_name="cumsum_test")
568-
check_correctness(model, {"axis": [1]})
568+
check_correctness(model)
569569

570570

571571
if __name__ == "__main__":
@@ -585,6 +585,7 @@ def test_cumsum():
585585
test_conv()
586586
test_pow()
587587
test_erf()
588+
test_cumsum()
588589

589590
# TODO, still has issues
590591
# test_reshape()
@@ -594,4 +595,3 @@ def test_cumsum():
594595
test_transpose()
595596
test_unsqueeze()
596597
# test_shape()
597-
# test_cumsum() # need axis as int

0 commit comments

Comments
 (0)