Skip to content

Commit 2ec8678

Browse files
authored
Arm backend: Refactor Quantizer test to allow for TOSA 1.0 (#10905)
### Summary Update quantizer unit tests to use the new test infrastructure pipeline.
1 parent 6f015f6 commit 2ec8678

File tree

1 file changed

+79
-71
lines changed

1 file changed

+79
-71
lines changed
Lines changed: 79 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
2-
# All rights reserved.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
65
import itertools
7-
import unittest
6+
7+
from typing import Tuple
88

99
import torch
1010
from executorch.backends.arm.quantizer import is_annotated
11-
from executorch.backends.arm.test import common
12-
from executorch.backends.arm.test.tester.arm_tester import ArmTester
11+
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineBI
12+
1313
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1414

1515

16+
input_t1 = Tuple[torch.Tensor] # Input x
17+
18+
1619
class SingleOpModel(torch.nn.Module):
1720
def __init__(self, op, example_input, **op_kwargs) -> None:
1821
super().__init__()
@@ -27,69 +30,74 @@ def example_inputs(self):
2730
return self._example_input
2831

2932

30-
class TestGenericAnnotator(unittest.TestCase):
31-
def check_annotation(self, model):
32-
tester = ArmTester(
33-
model,
34-
model.example_inputs(),
35-
common.get_tosa_compile_spec("TOSA-0.80+BI"),
36-
)
37-
quant_model = tester.quantize().get_artifact()
38-
partitions = get_source_partitions(quant_model.graph, [model.op])
39-
partitions = list(itertools.chain.from_iterable(partitions.values()))
40-
41-
assert len(partitions) == 1
42-
partition = partitions[0]
43-
assert all(is_annotated(node) for node in partition.nodes)
44-
45-
def test_squeeze(self):
46-
self.check_annotation(SingleOpModel(torch.squeeze, (torch.rand(8, 8, 1),)))
47-
self.check_annotation(SingleOpModel(torch.squeeze_copy, (torch.rand(8, 8, 1),)))
48-
49-
def test_unsqueeze(self):
50-
self.check_annotation(
51-
SingleOpModel(torch.unsqueeze, (torch.rand(8, 8),), dim=0)
52-
)
53-
self.check_annotation(
54-
SingleOpModel(torch.unsqueeze_copy, (torch.rand(8, 8),), dim=0)
55-
)
56-
57-
def test_reshape(self):
58-
self.check_annotation(
59-
SingleOpModel(torch.reshape, (torch.randn(8, 8),), shape=(64,)),
60-
)
61-
62-
def test_view(self):
63-
self.check_annotation(
64-
SingleOpModel(torch.view_copy, (torch.randn(4, 4),), size=(2, 8)),
65-
)
66-
67-
def test_slice(self):
68-
self.check_annotation(
69-
SingleOpModel(torch.slice_copy, (torch.randn(3, 4),)),
70-
)
71-
72-
def test_transpose(self):
73-
self.check_annotation(
74-
SingleOpModel(torch.transpose, (torch.randn(2, 3),), dim0=0, dim1=1),
75-
)
76-
self.check_annotation(
77-
SingleOpModel(torch.transpose_copy, (torch.randn(2, 3),), dim0=0, dim1=1),
78-
)
79-
80-
def test_tile(self):
81-
self.check_annotation(
82-
SingleOpModel(torch.tile, (torch.randn(4, 4),), dims=(2,)),
83-
)
84-
85-
def test_flip(self):
86-
self.check_annotation(
87-
SingleOpModel(torch.flip, (torch.randn(2, 4),), dims=(0, 1)),
88-
)
89-
90-
def test_concat(self):
91-
self.check_annotation(
92-
SingleOpModel(
93-
torch.concatenate, ((torch.randn(2, 3), torch.randn(2, 3)),), dim=0
94-
),
95-
)
33+
def check_annotation(model):
34+
pipeline = TosaPipelineBI[input_t1](model, model.example_inputs(), [], [])
35+
pipeline.pop_stage("check_count.exir")
36+
pipeline.pop_stage("run_method_and_compare_outputs")
37+
pipeline.run()
38+
39+
artifact = pipeline.tester.get_artifact("Quantize")
40+
41+
partitions = get_source_partitions(artifact.graph, [model.op])
42+
partitions = list(itertools.chain.from_iterable(partitions.values()))
43+
44+
assert len(partitions) == 1
45+
partition = partitions[0]
46+
assert all(is_annotated(node) for node in partition.nodes)
47+
48+
49+
def test_squeeze():
50+
check_annotation(SingleOpModel(torch.squeeze, (torch.rand(8, 8, 1),)))
51+
check_annotation(SingleOpModel(torch.squeeze_copy, (torch.rand(8, 8, 1),)))
52+
53+
54+
def test_unsqueeze():
55+
check_annotation(SingleOpModel(torch.unsqueeze, (torch.rand(8, 8),), dim=0))
56+
check_annotation(SingleOpModel(torch.unsqueeze_copy, (torch.rand(8, 8),), dim=0))
57+
58+
59+
def test_reshape():
60+
check_annotation(
61+
SingleOpModel(torch.reshape, (torch.randn(8, 8),), shape=(64,)),
62+
)
63+
64+
65+
def test_view():
66+
check_annotation(
67+
SingleOpModel(torch.view_copy, (torch.randn(4, 4),), size=(2, 8)),
68+
)
69+
70+
71+
def test_slice():
72+
check_annotation(
73+
SingleOpModel(torch.slice_copy, (torch.randn(3, 4),)),
74+
)
75+
76+
77+
def test_transpose():
78+
check_annotation(
79+
SingleOpModel(torch.transpose, (torch.randn(2, 3),), dim0=0, dim1=1),
80+
)
81+
check_annotation(
82+
SingleOpModel(torch.transpose_copy, (torch.randn(2, 3),), dim0=0, dim1=1),
83+
)
84+
85+
86+
def test_tile():
87+
check_annotation(
88+
SingleOpModel(torch.tile, (torch.randn(4, 4),), dims=(2,)),
89+
)
90+
91+
92+
def test_flip():
93+
check_annotation(
94+
SingleOpModel(torch.flip, (torch.randn(2, 4),), dims=(0, 1)),
95+
)
96+
97+
98+
def test_concat():
99+
check_annotation(
100+
SingleOpModel(
101+
torch.concatenate, ((torch.randn(2, 3), torch.randn(2, 3)),), dim=0
102+
),
103+
)

0 commit comments

Comments
 (0)