
Commit 7be57df

Merge the various data maps and lists into a unified structure | test(torchlib) (#801)

- Create `TorchLibOpInfo` to gather everything about an op in a single place (sketched below).
- Fix `dtype_op_schema_compatible` to account for sequence types.
- Fix `_aten_var_mean_dim_onnx` to support float16.
- Remove the allow list in favor of xfails and skips. In a follow-up PR we will auto-create new xfails and skips when we enable tests for new data types.
1 parent da1bc00 · commit 7be57df
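`TorchLibOpInfo` itself is defined in `ops_test_data.py`, which this page does not show. As a rough sketch of the idea only, with field names that are assumptions rather than the actual definition, a unified entry might look like:

```python
import dataclasses
from typing import Any, Callable, Optional


@dataclasses.dataclass
class TorchLibOpInfo:
    """Everything the test harness tracks for one op, in a single record.

    Illustrative sketch only; see ops_test_data.py for the real definition.
    """

    op_info_name: str  # name of the matching OpInfo in OPS_DB
    op: Callable[..., Any]  # the torchlib function under test
    trace_only: bool = False  # replaces the separate trace-only mapping
    complex: bool = False  # replaces the separate complex-op list
    input_wrangler: Optional[Callable[..., Any]] = None
    # Per-dtype xfails/skips would hang off this record too, replacing
    # the removed allow list.
```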

5 files changed: +1068 −1584 lines

noxfile.py (+1 −1)

@@ -28,7 +28,7 @@
 )
 ONNX = "onnx==1.14.0"
 ONNX_RUNTIME = "onnxruntime==1.15.0"
-PYTORCH = "torch==2.0.0"
+PYTORCH = "torch==2.0.1"
 ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = (
     "flatbuffers",
     "coloredlogs",

onnxscript/function_libs/torch_lib/ops/core.py (+1 −1)

@@ -6669,7 +6669,7 @@ def _aten_var_mean_dim_onnx(
     if correction > 0.0:
         self_shape = op.Shape(self)
         dim_size = op.Gather(self_shape, dim, axis=0)
-        numel_float = op.Cast(op.ReduceProd(dim_size, keepdims=0), to=FLOAT.dtype)
+        numel_float = op.CastLike(op.ReduceProd(dim_size, keepdims=0), self)
         mul = op.Mul(var, numel_float)
         sub = op.Sub(numel_float, correction)
         var = op.Div(mul, sub)
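The float16 fix works because `op.Cast(..., to=FLOAT.dtype)` always produced a float32 element count, which then met the float16 `var` in the following `Mul`/`Sub`/`Div`, while `op.CastLike(..., self)` keeps the count in the input's own dtype. A minimal PyTorch sketch of the same correction arithmetic (function name and structure are illustrative, not from the source):

```python
import torch


def corrected_var(x: torch.Tensor, dim: int, correction: float) -> torch.Tensor:
    """Rescale a biased (correction=0) variance to an arbitrary correction,
    mirroring the Mul/Sub/Div arithmetic in _aten_var_mean_dim_onnx."""
    var = x.var(dim=dim, correction=0)  # biased variance
    # The point of the fix: keep the element count in x's dtype (CastLike)
    # instead of forcing float32 (Cast to FLOAT), so float16 stays float16.
    numel = torch.tensor(x.shape[dim], dtype=x.dtype)
    return var * numel / (numel - correction)


x = torch.randn(4, 8)
print(torch.allclose(corrected_var(x, dim=1, correction=1.0),
                     x.var(dim=1, correction=1)))  # True
```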

onnxscript/tests/function_libs/torch_lib/ops_test.py (+12 −27)

@@ -22,7 +22,6 @@
 import onnx
 import onnxruntime as ort
 import parameterized
-import pytest
 import torch
 from torch.testing._internal import common_device_type
 from torch.testing._internal.opinfo import core as opinfo_core
@@ -33,8 +32,7 @@
 from onnxscript.tests.function_libs.torch_lib import ops_test_common, ops_test_data
 
 # All dtypes will be tested on the generated symbolic functions.
-# complex64 would be flattened to float32.
-# add new dtype in the tuple, and also add the new typpe in OPINFO_FUNCTION_TARGET_DTYPE right after the aten function you are testing
+# complex64 will be flattened to float32.
 TESTED_DTYPES = (
     torch.float16,
     torch.float32,
@@ -90,23 +88,16 @@ def _split_function_and_wrangler(
 
 
 # according to https://pytorch.org/docs/stable/testing.html
-OPINFO_PRECISION_TABLE = {
+OPINFO_PRECISION_TABLE: dict[torch.dtype, tuple[float, float]] = {
     # Tolerance value (rtol, atol)
     # The current most relaxed values are for aten::matmul
     torch.float32: (3.7e-5, 1.8e-4),  # default is 1.3e-6, 1e-5
     torch.float16: (1e-3, 1e-5),  # default is 1e-3, 1e-5
 }
 
 
-def _get_rtol_atol_by_dtype(dtype: torch.dtype) -> tuple(Any, Any):
-    if dtype in OPINFO_PRECISION_TABLE:
-        return OPINFO_PRECISION_TABLE[dtype]
-    return (None, None)
-
-
-def _dtype_is_supported_by_op(op_name: str, dtype: torch.dtype) -> bool:
-    dtype_list = ops_test_data.OPINFO_FUNCTION_TARGET_DTYPE.get(op_name)
-    return dtype in dtype_list
+def _get_rtol_atol_by_dtype(dtype: torch.dtype) -> tuple[Any, Any]:
+    return OPINFO_PRECISION_TABLE.get(dtype, (None, None))
 
 
 class TestFunctionValidity(unittest.TestCase):
@@ -121,8 +112,8 @@ def test_all_script_functions_are_onnx_functions(self):
         if not isinstance(func, onnxscript.OnnxFunction):
             raise AssertionError(
                 f"'{func}' is not an OnnxFunction. Was it decorated with '@torch_op'? "
-                "If the function is trace_only, please move it to the "
-                "'ops_test_data.OPINFO_FUNCTION_MAPPING_TRACE_ONLY' dict."
+                "If the function is trace_only, please specify trace_only=True "
+                "in the TorchLibOpInfo entry."
             )
 
     def test_all_trace_only_functions_are_not_onnx_functions(self):
@@ -131,8 +122,8 @@ def test_all_trace_only_functions_are_not_onnx_functions(self):
         if isinstance(func, onnxscript.OnnxFunction):
             raise AssertionError(
                 f"'{func.name}' is an OnnxFunction. "
-                "If the function is not trace_only, please move it to the "
-                "'ops_test_data.OPINFO_FUNCTION_MAPPING_SCRIPTED' dict."
+                "If the function is not trace_only, please remove trace_only=True "
+                "in the TorchLibOpInfo entry."
             )
 
     @parameterized.parameterized.expand(
@@ -318,9 +309,6 @@ def setUp(self) -> None:
     def test_output_match_opinfo_(
         self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
     ):
-        if not _dtype_is_supported_by_op(op.name, dtype):
-            pytest.skip(reason=f"{op.name} cannot support {dtype}")
-
         # Base test method for testing each op with the eager executor, used by instantiate_device_type_tests.
         run_test_output_match(
             self,
@@ -341,7 +329,7 @@
         [
             info
            for info in ops_test_data.OPS_DB
-            if info.name in ops_test_data.COMPLEX_TESTED_OPS
+            if info.name in ops_test_data.COMPLEX_FUNCTION_MAPPING
         ],
         allowed_dtypes=COMPLEX_TYPES,
     )
@@ -355,7 +343,7 @@ def test_complex_output_match_opinfo_(
             dtype,
             op,
             ops_test_common.eager_executor,
-            ops_test_data.COMPLEX_FUNCTION_MAPPING_SCRIPTED,
+            ops_test_data.COMPLEX_FUNCTION_MAPPING,
         )
 
 
@@ -383,9 +371,6 @@ def setUp(self) -> None:
     def test_output_match_opinfo_(
         self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
     ):
-        if not _dtype_is_supported_by_op(op.name, dtype):
-            pytest.skip(reason=f"{op.name} cannot support {dtype}")
-
         # Base test method for testing each op by running the full ONNX graph.
         run_test_output_match(
             self,
@@ -406,7 +391,7 @@
         [
             info
             for info in ops_test_data.OPS_DB
-            if info.name in ops_test_data.COMPLEX_TESTED_OPS
+            if info.name in ops_test_data.COMPLEX_FUNCTION_MAPPING
         ],
         allowed_dtypes=COMPLEX_TYPES,
     )
@@ -420,7 +405,7 @@ def test_complex_output_match_opinfo_(
             dtype,
             op,
             ops_test_common.graph_executor,
-            ops_test_data.COMPLEX_FUNCTION_MAPPING_SCRIPTED,
+            ops_test_data.COMPLEX_FUNCTION_MAPPING,
         )
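The rewritten `_get_rtol_atol_by_dtype` folds the membership test into a single `dict.get` with `(None, None)` as the default; `None` tells `torch.testing.assert_close` to use its own per-dtype defaults. A standalone illustration of the pattern (the tensors and values below are made up for the demo):

```python
from typing import Any

import torch

OPINFO_PRECISION_TABLE: dict[torch.dtype, tuple[float, float]] = {
    torch.float32: (3.7e-5, 1.8e-4),  # relaxed from the 1.3e-6, 1e-5 default
    torch.float16: (1e-3, 1e-5),
}


def _get_rtol_atol_by_dtype(dtype: torch.dtype) -> tuple[Any, Any]:
    # dict.get with a default replaces the explicit if/return pair.
    return OPINFO_PRECISION_TABLE.get(dtype, (None, None))


rtol, atol = _get_rtol_atol_by_dtype(torch.float16)
torch.testing.assert_close(  # passes: difference is within rtol=1e-3
    torch.tensor(1.0, dtype=torch.float16),
    torch.tensor(1.001, dtype=torch.float16),
    rtol=rtol,
    atol=atol,
)
rtol, atol = _get_rtol_atol_by_dtype(torch.int64)  # (None, None): use defaults
```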

onnxscript/tests/function_libs/torch_lib/ops_test_common.py (+3 −1)

@@ -424,7 +424,9 @@ def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -
     first_input_type_name = schema.inputs[0].type_str
     # Find the type constraint for the first input by matching the parameter name
     first_input_type_constraint = next(
-        (x for x in schema.type_constraints if x.type_param_str == first_input_type_name), None
+        # Here we consider seq(tensor(float)) compatible with tensor(float) as well
+        (x for x in schema.type_constraints if first_input_type_name in x.type_param_str),
+        None,
     )
     assert first_input_type_constraint is not None
     allowed_type_strs = first_input_type_constraint.allowed_type_strs
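For context, the schema objects this predicate walks can be probed directly with the `onnx` API; the snippet below is ordinary library usage against a real schema, not code from the PR:

```python
import onnx.defs

schema = onnx.defs.get_schema("MatMul")
first_input_type_name = schema.inputs[0].type_str
print(first_input_type_name)  # "T"

# The relaxed membership test from the diff: substring rather than equality.
constraint = next(
    x for x in schema.type_constraints if first_input_type_name in x.type_param_str
)
print(constraint.allowed_type_strs)  # ['tensor(float16)', 'tensor(float)', ...]
```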
