Skip to content

Commit 9a47fc0

Browse files
AndrewZhaoLuoAndrew Zhao Luo
andauthored
[Onnx] Pow support for other types (#8933)
* update pow * update pow * remove duplicate Co-authored-by: Andrew Zhao Luo <andrewzhaoluo@system76-pc.localdomain>
1 parent 475e9e0 commit 9a47fc0

File tree

2 files changed

+27
-11
lines changed

2 files changed

+27
-11
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,32 @@ def _impl_v1(cls, inputs, attr, params):
10211021
return _op.log(_op.exp(beta * inputs[0]) + _expr.const(1.0)) * alpha
10221022

10231023

1024+
class Pow(OnnxOpConverter):
1025+
"""Operator converter for Pow."""
1026+
1027+
@classmethod
1028+
def _impl_v13(cls, inputs, attr, params):
1029+
x = inputs[0]
1030+
y = inputs[1]
1031+
1032+
x_type = infer_type(x).checked_type.dtype
1033+
output_type = x_type
1034+
y_type = infer_type(y).checked_type.dtype
1035+
1036+
if not x_type.startswith("float"):
1037+
x_type = "float32"
1038+
x = _op.cast(x, x_type)
1039+
1040+
if x_type != y_type:
1041+
y = _op.cast(y, x_type)
1042+
1043+
# TODO: come up with good default integer pow() func for common backends
1044+
result = _op.power(x, y)
1045+
if x_type != output_type:
1046+
return _op.cast(result, output_type)
1047+
return result
1048+
1049+
10241050
class Prelu(OnnxOpConverter):
10251051
"""Operator converter for Prelu."""
10261052

@@ -3654,7 +3680,7 @@ def _get_convert_map(opset):
36543680
"Sinh": Renamer("sinh"),
36553681
"Tan": Renamer("tan"),
36563682
"Tanh": Renamer("tanh"),
3657-
"Pow": Renamer("power"),
3683+
"Pow": Pow.get_converter(opset),
36583684
"PRelu": Prelu.get_converter(opset),
36593685
"Sigmoid": Renamer("sigmoid"),
36603686
"HardSigmoid": HardSigmoid.get_converter(opset),

tests/python/frontend/onnx/test_forward.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4788,16 +4788,6 @@ def verify_eyelike(indata):
47884788
# This nllloss test is flaky and sometimes gives NaNs
47894789
# Investigate it here: https://github.com/apache/tvm/issues/8918
47904790
"test_nllloss_NCd1d2d3_none_no_weight_negative_ii",
4791-
"test_pow_types_float",
4792-
"test_pow_types_float32_int32",
4793-
"test_pow_types_float32_int64",
4794-
"test_pow_types_float32_uint32",
4795-
"test_pow_types_float32_uint64",
4796-
"test_pow_types_int",
4797-
"test_pow_types_int32_float32",
4798-
"test_pow_types_int32_int32",
4799-
"test_pow_types_int64_float32",
4800-
"test_pow_types_int64_int64",
48014791
"test_qlinearmatmul_2D",
48024792
"test_qlinearmatmul_3D",
48034793
"test_range_float_type_positive_delta_expanded",

0 commit comments

Comments
 (0)