Skip to content

Commit 1b99762

Browse files
authored
Merge branch 'main' into unsupported-u55
2 parents 151e1e3 + e22b21a commit 1b99762

File tree

3 files changed

+276
-120
lines changed

3 files changed

+276
-120
lines changed

backends/arm/test/ops/test_scalars.py

Lines changed: 164 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,16 @@
55

66
import unittest
77

8+
from typing import Tuple
9+
10+
import common
811
import torch
912

10-
from executorch.backends.arm.test import common
11-
from executorch.backends.arm.test.tester.arm_tester import ArmTester
12-
from parameterized import parameterized
13+
from executorch.backends.arm.test.tester.test_pipeline import (
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
TransformAnnotationPassPipeline,
17+
)
1318

1419
"""
1520
Summary of non-working cases.
@@ -24,6 +29,7 @@
2429
# MLETORCH-408
2530
Sub or inplace-sub with an integer input.
2631
"""
32+
input_t1 = Tuple[torch.Tensor, torch.scalar_tensor] # Input x, Input y
2733

2834

2935
class TestScalars(unittest.TestCase):
@@ -92,112 +98,160 @@ def forward(self, x):
9298
x -= 10
9399
return x
94100

95-
# Inplace ops end with '_' (from aten naming)
96-
ops = [
97-
("Add", Add()),
98-
("Sub", Sub()),
99-
("Mul", Mul()),
100-
("Div", Div()),
101-
("Add_", AddInplace()),
102-
("Sub_", SubInplace()),
103-
("Mul_", MulInplace()),
104-
("Div_", DivInplace()),
105-
("MulScalar", MulScalar()),
106-
("DivScalar", DivScalar()),
107-
("AddScalar", AddScalar()),
108-
("SubScalar", SubScalar()),
109-
]
110-
111-
const_ops = [("Add", AddConst())]
112-
113-
dtypes = [("int", 3), ("float", 3.0)]
114-
sizes = [("r1", (1)), ("r4", (2, 4, 5, 3))]
115-
116-
# Create combinations of tests
117-
tensor_scalar_tests = []
118-
for op in ops:
119-
for dtype in dtypes:
120-
for size in sizes:
121-
test_name = f"{op[0]}_{dtype[0]}_{size[0]}"
122-
tensor = torch.rand(size[1])
123-
scalar = dtype[1]
124-
tensor_scalar_tests.append((test_name + "_ts", op[1], tensor, scalar))
125-
126-
# Don't add (scalar, tensor) test case for .Scalar ops.
127-
if op[0][-6:] == "Scalar":
128-
continue
129-
130-
tensor_scalar_tests.append((test_name + "_st", op[1], scalar, tensor))
131-
132-
tensor_const_tests = []
133-
for op in const_ops:
101+
102+
# Inplace ops end with '_' (from aten naming)
103+
ops = [
104+
("Add", TestScalars.Add()),
105+
("Sub", TestScalars.Sub()),
106+
("Mul", TestScalars.Mul()),
107+
("Div", TestScalars.Div()),
108+
("Add_", TestScalars.AddInplace()),
109+
("Sub_", TestScalars.SubInplace()),
110+
("Mul_", TestScalars.MulInplace()),
111+
("Div_", TestScalars.DivInplace()),
112+
("MulScalar", TestScalars.MulScalar()),
113+
("DivScalar", TestScalars.DivScalar()),
114+
("AddScalar", TestScalars.AddScalar()),
115+
("SubScalar", TestScalars.SubScalar()),
116+
]
117+
118+
const_ops = [("Add", TestScalars.AddConst())]
119+
120+
dtypes = [("int", 3), ("float", 3.0)]
121+
sizes = [("r1", (1)), ("r4", (2, 4, 5, 3))]
122+
123+
# Create combinations of tests
124+
tensor_scalar_tests = {}
125+
for op in ops:
126+
for dtype in dtypes:
134127
for size in sizes:
135-
test_name = f"{op[0]}_{size[0]}"
128+
test_name = f"{op[0]}_{dtype[0]}_{size[0]}"
136129
tensor = torch.rand(size[1])
137-
tensor_const_tests.append((test_name, op[1], tensor))
138-
139-
def _test_add_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: tuple):
140-
(
141-
ArmTester(
142-
module,
143-
example_inputs=test_data,
144-
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
145-
)
146-
.export()
147-
.to_edge()
148-
.partition()
149-
.to_executorch()
150-
.run_method_and_compare_outputs(inputs=test_data)
151-
)
152-
153-
def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple):
154-
(
155-
ArmTester(
156-
module,
157-
example_inputs=test_data,
158-
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
159-
)
160-
.quantize()
161-
.export()
162-
.to_edge()
163-
.partition()
164-
.to_executorch()
165-
.run_method_and_compare_outputs(inputs=test_data)
166-
)
167-
168-
@parameterized.expand(tensor_scalar_tests)
169-
def test_MI(self, test_name: str, op: torch.nn.Module, x, y):
170-
expected_exception = None
171-
if any(token in test_name for token in ("Sub_int", "Sub__int")):
172-
expected_exception = AssertionError
173-
if test_name.endswith("_st"):
174-
expected_exception = AttributeError
175-
176-
if expected_exception:
177-
with self.assertRaises(
178-
expected_exception, msg=f"Test {test_name} is expected to fail."
179-
):
180-
self._test_add_tosa_MI_pipeline(op, (x, y))
181-
return
182-
183-
self._test_add_tosa_MI_pipeline(op, (x, y))
184-
185-
# op(Scalar float, tensor) works if the scalar is constant.
186-
@parameterized.expand(tensor_const_tests)
187-
def test_MI_const(self, test_name: str, op: torch.nn.Module, x):
188-
self._test_add_tosa_MI_pipeline(op, (x,))
189-
190-
@parameterized.expand(tensor_scalar_tests)
191-
def test_BI(self, test_name: str, op: torch.nn.Module, x, y):
192-
self._test_add_tosa_BI_pipeline(op, (x, y))
193-
194-
# op(Scalar float, tensor) works if the scalar is constant.
195-
@parameterized.expand(tensor_const_tests)
196-
def test_BI_const(self, test_name: str, op: torch.nn.Module, x):
197-
self._test_add_tosa_BI_pipeline(op, (x,))
198-
199-
def test_shift_sub_inplace_tosa_MI(self):
200-
self._test_add_tosa_MI_pipeline(self.ShiftInplaceSub(), (torch.IntTensor(5),))
201-
202-
def test_shift_sub_inplace_tosa_BI(self):
203-
self._test_add_tosa_BI_pipeline(self.ShiftInplaceSub(), (torch.IntTensor(5),))
130+
scalar = dtype[1]
131+
tensor_scalar_tests[test_name + "_ts"] = (op[1], tensor, scalar)
132+
# Don't add (scalar, tensor) test case for .Scalar ops.
133+
if op[0][-6:] == "Scalar":
134+
continue
135+
136+
tensor_scalar_tests[test_name + "_st"] = (op[1], scalar, tensor)
137+
138+
tensor_const_tests = {}
139+
for op in const_ops:
140+
for size in sizes:
141+
test_name = f"{op[0]}_{size[0]}"
142+
tensor = torch.rand(size[1])
143+
tensor_const_tests[test_name] = (op[1], tensor)
144+
145+
146+
def _test_add_tosa_MI_pipeline(module: torch.nn.Module, test_data: tuple):
147+
pipeline = TosaPipelineMI[input_t1](module, test_data, aten_op=[], exir_op=[])
148+
pipeline.run()
149+
150+
151+
def _test_add_tosa_BI_pipeline(
152+
module: torch.nn.Module, test_data: tuple, check_quant_nodes=True
153+
):
154+
pipeline = TosaPipelineBI[input_t1](module, test_data, aten_op=[], exir_op=[])
155+
if not check_quant_nodes:
156+
pipeline.pop_stage("check.quant_nodes")
157+
pipeline.run()
158+
159+
160+
fail_str = "MLETORCH-408: Arithmetic ops can't handle scalars first for MI"
161+
MI_xfails = {
162+
"Add_int_r1_st": fail_str,
163+
"Add_int_r4_st": fail_str,
164+
"Add_float_r1_st": fail_str,
165+
"Add_float_r4_st": fail_str,
166+
"Sub_int_r1_ts": fail_str,
167+
"Sub_int_r1_st": fail_str,
168+
"Sub_int_r4_ts": fail_str,
169+
"Sub_int_r4_st": fail_str,
170+
"Sub_float_r1_st": fail_str,
171+
"Sub_float_r4_st": fail_str,
172+
"Mul_int_r1_st": fail_str,
173+
"Mul_int_r4_st": fail_str,
174+
"Mul_float_r1_st": fail_str,
175+
"Mul_float_r4_st": fail_str,
176+
"Div_int_r1_st": fail_str,
177+
"Div_int_r4_st": fail_str,
178+
"Div_float_r1_st": fail_str,
179+
"Div_float_r4_st": fail_str,
180+
"Add__int_r1_st": fail_str,
181+
"Add__float_r1_st": fail_str,
182+
"Add__float_r4_st": fail_str,
183+
"Add__int_r4_st": fail_str,
184+
"Sub__int_r1_ts": fail_str,
185+
"Sub__int_r1_st": fail_str,
186+
"Sub__int_r4_ts": fail_str,
187+
"Sub__int_r4_st": fail_str,
188+
"Sub__float_r1_st": fail_str,
189+
"Sub__float_r4_st": fail_str,
190+
"Mul__int_r1_st": fail_str,
191+
"Mul__int_r4_st": fail_str,
192+
"Mul__float_r1_st": fail_str,
193+
"Mul__float_r4_st": fail_str,
194+
"Div__int_r1_st": fail_str,
195+
"Div__int_r4_st": fail_str,
196+
"Div__float_r1_st": fail_str,
197+
"Div__float_r4_st": fail_str,
198+
}
199+
200+
201+
@common.parametrize("tensor_scalar_tests", tensor_scalar_tests, MI_xfails)
202+
def test_MI(tensor_scalar_tests: list):
203+
op, x, y = tensor_scalar_tests
204+
_test_add_tosa_MI_pipeline(op, (x, y))
205+
206+
207+
def _test_passes_tosa_BI_pipeline(module: torch.nn.Module, test_data: tuple):
208+
pipeline = TransformAnnotationPassPipeline[input_t1](
209+
module, test_data, tosa_version="TOSA-0.80+BI"
210+
)
211+
pipeline.run()
212+
213+
214+
fail_str = "MLETORCH-770: Numerical issues on Div Scalar."
215+
passes_xfails = {
216+
"Div__int_r1_ts": fail_str,
217+
"Div__int_r4_ts": fail_str,
218+
"Div__float_r1_ts": fail_str,
219+
"Div__float_r4_ts": fail_str,
220+
}
221+
222+
223+
@common.parametrize("tensor_scalar_tests", tensor_scalar_tests, passes_xfails)
224+
def test_passes_BI(tensor_scalar_tests: list):
225+
op, x, y = tensor_scalar_tests
226+
_test_passes_tosa_BI_pipeline(op, (x, y))
227+
228+
229+
# op(Scalar float, tensor) works if the scalar is constant.
230+
@common.parametrize("tensor_const_tests", tensor_const_tests)
231+
def test_MI_const(tensor_const_tests: list):
232+
op, x = tensor_const_tests
233+
_test_add_tosa_MI_pipeline(op, (x,))
234+
235+
236+
@common.parametrize("tensor_scalar_tests", tensor_scalar_tests)
237+
def test_BI(tensor_scalar_tests: list):
238+
op, x, y = tensor_scalar_tests
239+
_test_add_tosa_BI_pipeline(op, (x, y))
240+
241+
242+
# op(Scalar float, tensor) works if the scalar is constant.
243+
@common.parametrize("tensor_const_tests", tensor_const_tests)
244+
def test_BI_const(tensor_const_tests: list):
245+
op, x = tensor_const_tests
246+
_test_add_tosa_BI_pipeline(op, (x,))
247+
248+
249+
def test_shift_sub_inplace_tosa_MI():
250+
_test_add_tosa_MI_pipeline(TestScalars.ShiftInplaceSub(), (torch.IntTensor(5),))
251+
252+
253+
# Do not check for quant nodes in the graph for rshift.
254+
def test_shift_sub_inplace_tosa_BI():
255+
_test_add_tosa_BI_pipeline(
256+
TestScalars.ShiftInplaceSub(), (torch.IntTensor(5),), check_quant_nodes=False
257+
)

0 commit comments

Comments
 (0)