22
22
import onnx
23
23
import onnxruntime as ort
24
24
import parameterized
25
- import pytest
26
25
import torch
27
26
from torch .testing ._internal import common_device_type
28
27
from torch .testing ._internal .opinfo import core as opinfo_core
33
32
from onnxscript .tests .function_libs .torch_lib import ops_test_common , ops_test_data
34
33
35
34
# All dtypes will be tested on the generated symbolic functions.
36
- # complex64 would be flattened to float32.
37
- # add new dtype in the tuple, and also add the new typpe in OPINFO_FUNCTION_TARGET_DTYPE right after the aten function you are testing
35
+ # complex64 will be flattened to float32.
38
36
TESTED_DTYPES = (
39
37
torch .float16 ,
40
38
torch .float32 ,
@@ -90,23 +88,16 @@ def _split_function_and_wrangler(
90
88
91
89
92
90
# according to https://pytorch.org/docs/stable/testing.html
93
- OPINFO_PRECISION_TABLE = {
91
+ OPINFO_PRECISION_TABLE : dict [ torch . dtype , tuple [ float , float ]] = {
94
92
# Tolerance value (rtol, atol)
95
93
# The current most relaxed values are for aten::matmul
96
94
torch .float32 : (3.7e-5 , 1.8e-4 ), # default is 1.3e-6, 1e-5
97
95
torch .float16 : (1e-3 , 1e-5 ), # default is 1e-3, 1e-5
98
96
}
99
97
100
98
101
- def _get_rtol_atol_by_dtype (dtype : torch .dtype ) -> tuple (Any , Any ):
102
- if dtype in OPINFO_PRECISION_TABLE :
103
- return OPINFO_PRECISION_TABLE [dtype ]
104
- return (None , None )
105
-
106
-
107
- def _dtype_is_supported_by_op (op_name : str , dtype : torch .dtype ) -> bool :
108
- dtype_list = ops_test_data .OPINFO_FUNCTION_TARGET_DTYPE .get (op_name )
109
- return dtype in dtype_list
99
+ def _get_rtol_atol_by_dtype (dtype : torch .dtype ) -> tuple [Any , Any ]:
100
+ return OPINFO_PRECISION_TABLE .get (dtype , (None , None ))
110
101
111
102
112
103
class TestFunctionValidity (unittest .TestCase ):
@@ -121,8 +112,8 @@ def test_all_script_functions_are_onnx_functions(self):
121
112
if not isinstance (func , onnxscript .OnnxFunction ):
122
113
raise AssertionError (
123
114
f"'{ func } ' is not an OnnxFunction. Was it decorated with '@torch_op'? "
124
- "If the function is trace_only, please move it to the "
125
- "'ops_test_data.OPINFO_FUNCTION_MAPPING_TRACE_ONLY' dict ."
115
+ "If the function is trace_only, please specify trace_only=True "
116
+ "in the TorchLibOpInfo entry ."
126
117
)
127
118
128
119
def test_all_trace_only_functions_are_not_onnx_functions (self ):
@@ -131,8 +122,8 @@ def test_all_trace_only_functions_are_not_onnx_functions(self):
131
122
if isinstance (func , onnxscript .OnnxFunction ):
132
123
raise AssertionError (
133
124
f"'{ func .name } ' is an OnnxFunction. "
134
- "If the function is not trace_only, please move it to the "
135
- "'ops_test_data.OPINFO_FUNCTION_MAPPING_SCRIPTED' dict ."
125
+ "If the function is not trace_only, please remove trace_only=True "
126
+ "in the TorchLibOpInfo entry ."
136
127
)
137
128
138
129
@parameterized .parameterized .expand (
@@ -318,9 +309,6 @@ def setUp(self) -> None:
318
309
def test_output_match_opinfo_ (
319
310
self , device : str , dtype : torch .dtype , op : opinfo_core .OpInfo
320
311
):
321
- if not _dtype_is_supported_by_op (op .name , dtype ):
322
- pytest .skip (reason = f"{ op .name } cannot support { dtype } " )
323
-
324
312
# Base test method for testing each op with the eager executor, used by instantiate_device_type_tests.
325
313
run_test_output_match (
326
314
self ,
@@ -341,7 +329,7 @@ def test_output_match_opinfo_(
341
329
[
342
330
info
343
331
for info in ops_test_data .OPS_DB
344
- if info .name in ops_test_data .COMPLEX_TESTED_OPS
332
+ if info .name in ops_test_data .COMPLEX_FUNCTION_MAPPING
345
333
],
346
334
allowed_dtypes = COMPLEX_TYPES ,
347
335
)
@@ -355,7 +343,7 @@ def test_complex_output_match_opinfo_(
355
343
dtype ,
356
344
op ,
357
345
ops_test_common .eager_executor ,
358
- ops_test_data .COMPLEX_FUNCTION_MAPPING_SCRIPTED ,
346
+ ops_test_data .COMPLEX_FUNCTION_MAPPING ,
359
347
)
360
348
361
349
@@ -383,9 +371,6 @@ def setUp(self) -> None:
383
371
def test_output_match_opinfo_ (
384
372
self , device : str , dtype : torch .dtype , op : opinfo_core .OpInfo
385
373
):
386
- if not _dtype_is_supported_by_op (op .name , dtype ):
387
- pytest .skip (reason = f"{ op .name } cannot support { dtype } " )
388
-
389
374
# Base test method for testing each op by running the full ONNX graph.
390
375
run_test_output_match (
391
376
self ,
@@ -406,7 +391,7 @@ def test_output_match_opinfo_(
406
391
[
407
392
info
408
393
for info in ops_test_data .OPS_DB
409
- if info .name in ops_test_data .COMPLEX_TESTED_OPS
394
+ if info .name in ops_test_data .COMPLEX_FUNCTION_MAPPING
410
395
],
411
396
allowed_dtypes = COMPLEX_TYPES ,
412
397
)
@@ -420,7 +405,7 @@ def test_complex_output_match_opinfo_(
420
405
dtype ,
421
406
op ,
422
407
ops_test_common .graph_executor ,
423
- ops_test_data .COMPLEX_FUNCTION_MAPPING_SCRIPTED ,
408
+ ops_test_data .COMPLEX_FUNCTION_MAPPING ,
424
409
)
425
410
426
411
0 commit comments