1
- # Copyright 2024 Arm Limited and/or its affiliates.
2
- # All rights reserved.
1
+ # Copyright 2024-2025 Arm Limited and/or its affiliates.
3
2
#
4
3
# This source code is licensed under the BSD-style license found in the
5
4
# LICENSE file in the root directory of this source tree.
6
5
import itertools
7
- import unittest
6
+
7
+ from typing import Tuple
8
8
9
9
import torch
10
10
from executorch .backends .arm .quantizer import is_annotated
11
- from executorch .backends .arm .test import common
12
- from executorch . backends . arm . test . tester . arm_tester import ArmTester
11
+ from executorch .backends .arm .test . tester . test_pipeline import TosaPipelineBI
12
+
13
13
from torch .fx .passes .utils .source_matcher_utils import get_source_partitions
14
14
15
15
16
+ input_t1 = Tuple [torch .Tensor ] # Input x
17
+
18
+
16
19
class SingleOpModel (torch .nn .Module ):
17
20
def __init__ (self , op , example_input , ** op_kwargs ) -> None :
18
21
super ().__init__ ()
@@ -27,69 +30,74 @@ def example_inputs(self):
27
30
return self ._example_input
28
31
29
32
30
- class TestGenericAnnotator (unittest .TestCase ):
31
- def check_annotation (self , model ):
32
- tester = ArmTester (
33
- model ,
34
- model .example_inputs (),
35
- common .get_tosa_compile_spec ("TOSA-0.80+BI" ),
36
- )
37
- quant_model = tester .quantize ().get_artifact ()
38
- partitions = get_source_partitions (quant_model .graph , [model .op ])
39
- partitions = list (itertools .chain .from_iterable (partitions .values ()))
40
-
41
- assert len (partitions ) == 1
42
- partition = partitions [0 ]
43
- assert all (is_annotated (node ) for node in partition .nodes )
44
-
45
- def test_squeeze (self ):
46
- self .check_annotation (SingleOpModel (torch .squeeze , (torch .rand (8 , 8 , 1 ),)))
47
- self .check_annotation (SingleOpModel (torch .squeeze_copy , (torch .rand (8 , 8 , 1 ),)))
48
-
49
- def test_unsqueeze (self ):
50
- self .check_annotation (
51
- SingleOpModel (torch .unsqueeze , (torch .rand (8 , 8 ),), dim = 0 )
52
- )
53
- self .check_annotation (
54
- SingleOpModel (torch .unsqueeze_copy , (torch .rand (8 , 8 ),), dim = 0 )
55
- )
56
-
57
- def test_reshape (self ):
58
- self .check_annotation (
59
- SingleOpModel (torch .reshape , (torch .randn (8 , 8 ),), shape = (64 ,)),
60
- )
61
-
62
- def test_view (self ):
63
- self .check_annotation (
64
- SingleOpModel (torch .view_copy , (torch .randn (4 , 4 ),), size = (2 , 8 )),
65
- )
66
-
67
- def test_slice (self ):
68
- self .check_annotation (
69
- SingleOpModel (torch .slice_copy , (torch .randn (3 , 4 ),)),
70
- )
71
-
72
- def test_transpose (self ):
73
- self .check_annotation (
74
- SingleOpModel (torch .transpose , (torch .randn (2 , 3 ),), dim0 = 0 , dim1 = 1 ),
75
- )
76
- self .check_annotation (
77
- SingleOpModel (torch .transpose_copy , (torch .randn (2 , 3 ),), dim0 = 0 , dim1 = 1 ),
78
- )
79
-
80
- def test_tile (self ):
81
- self .check_annotation (
82
- SingleOpModel (torch .tile , (torch .randn (4 , 4 ),), dims = (2 ,)),
83
- )
84
-
85
- def test_flip (self ):
86
- self .check_annotation (
87
- SingleOpModel (torch .flip , (torch .randn (2 , 4 ),), dims = (0 , 1 )),
88
- )
89
-
90
- def test_concat (self ):
91
- self .check_annotation (
92
- SingleOpModel (
93
- torch .concatenate , ((torch .randn (2 , 3 ), torch .randn (2 , 3 )),), dim = 0
94
- ),
95
- )
33
+ def check_annotation (model ):
34
+ pipeline = TosaPipelineBI [input_t1 ](model , model .example_inputs (), [], [])
35
+ pipeline .pop_stage ("check_count.exir" )
36
+ pipeline .pop_stage ("run_method_and_compare_outputs" )
37
+ pipeline .run ()
38
+
39
+ artifact = pipeline .tester .get_artifact ("Quantize" )
40
+
41
+ partitions = get_source_partitions (artifact .graph , [model .op ])
42
+ partitions = list (itertools .chain .from_iterable (partitions .values ()))
43
+
44
+ assert len (partitions ) == 1
45
+ partition = partitions [0 ]
46
+ assert all (is_annotated (node ) for node in partition .nodes )
47
+
48
+
49
+ def test_squeeze ():
50
+ check_annotation (SingleOpModel (torch .squeeze , (torch .rand (8 , 8 , 1 ),)))
51
+ check_annotation (SingleOpModel (torch .squeeze_copy , (torch .rand (8 , 8 , 1 ),)))
52
+
53
+
54
+ def test_unsqueeze ():
55
+ check_annotation (SingleOpModel (torch .unsqueeze , (torch .rand (8 , 8 ),), dim = 0 ))
56
+ check_annotation (SingleOpModel (torch .unsqueeze_copy , (torch .rand (8 , 8 ),), dim = 0 ))
57
+
58
+
59
+ def test_reshape ():
60
+ check_annotation (
61
+ SingleOpModel (torch .reshape , (torch .randn (8 , 8 ),), shape = (64 ,)),
62
+ )
63
+
64
+
65
+ def test_view ():
66
+ check_annotation (
67
+ SingleOpModel (torch .view_copy , (torch .randn (4 , 4 ),), size = (2 , 8 )),
68
+ )
69
+
70
+
71
+ def test_slice ():
72
+ check_annotation (
73
+ SingleOpModel (torch .slice_copy , (torch .randn (3 , 4 ),)),
74
+ )
75
+
76
+
77
+ def test_transpose ():
78
+ check_annotation (
79
+ SingleOpModel (torch .transpose , (torch .randn (2 , 3 ),), dim0 = 0 , dim1 = 1 ),
80
+ )
81
+ check_annotation (
82
+ SingleOpModel (torch .transpose_copy , (torch .randn (2 , 3 ),), dim0 = 0 , dim1 = 1 ),
83
+ )
84
+
85
+
86
+ def test_tile ():
87
+ check_annotation (
88
+ SingleOpModel (torch .tile , (torch .randn (4 , 4 ),), dims = (2 ,)),
89
+ )
90
+
91
+
92
+ def test_flip ():
93
+ check_annotation (
94
+ SingleOpModel (torch .flip , (torch .randn (2 , 4 ),), dims = (0 , 1 )),
95
+ )
96
+
97
+
98
+ def test_concat ():
99
+ check_annotation (
100
+ SingleOpModel (
101
+ torch .concatenate , ((torch .randn (2 , 3 ), torch .randn (2 , 3 )),), dim = 0
102
+ ),
103
+ )
0 commit comments