Skip to content

Commit c52741b

Browse files
Andrew Zhao LuoAndrewZhaoLuo
authored andcommitted
make tests pass for momentum
1 parent 79e5740 commit c52741b

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3926,6 +3926,7 @@ def _get_convert_map(opset):
39263926
"NegativeLogLikelihoodLoss": NegativeLogLikelihoodLoss.get_converter(opset),
39273927
"Adagrad": Adagrad.get_converter(opset),
39283928
"Adam": Adam.get_converter(opset),
3929+
"Momentum": Momentum.get_converter(opset),
39293930
}
39303931

39313932

tests/python/frontend/onnx/test_forward.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,8 @@ def verify_with_ort(
235235

236236

237237
def quantize_and_verify_with_ort(onnx_model, input_names, input_shapes, target, dev):
238-
from onnxruntime.quantization import CalibrationDataReader, QuantType, quantize_static
238+
from onnxruntime.quantization import (CalibrationDataReader, QuantType,
239+
quantize_static)
239240

240241
input_arrays = [np.random.random(shape).astype("float32") for shape in input_shapes]
241242

@@ -4760,10 +4761,7 @@ def verify_eyelike(indata):
47604761
"test_maxpool_with_argmax_2d_precomputed_pads",
47614762
"test_maxpool_with_argmax_2d_precomputed_strides",
47624763
"test_maxunpool_export_with_output_shape",
4763-
"test_momentum",
4764-
"test_momentum_multiple",
47654764
"test_mvn",
4766-
"test_nesterov_momentum",
47674765
# When unsqueeze is fully supported, remaining nllloss tests should work:
47684766
"test_nllloss_NC_expanded",
47694767
"test_nllloss_NCd1_expanded",

0 commit comments

Comments
 (0)