4
4
# LICENSE file in the root directory of this source tree.
5
5
6
6
# pyre-unsafe
7
- from typing import List
7
+ from typing import Any , List
8
8
9
9
import torch
10
10
11
- import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
12
-
13
11
from executorch .backends .arm ._passes .fold_qdq_with_annotated_qparams_pass import (
14
12
get_input_qparams ,
15
13
get_output_qparams ,
@@ -36,14 +34,16 @@ def __init__(self, *args):
36
34
def _build_generic_avgpool2d (
37
35
self ,
38
36
node : torch .fx .Node ,
39
- tosa_graph : ts . TosaSerializer ,
37
+ tosa_graph : Any ,
40
38
inputs : List [TosaArg ],
41
39
output : TosaArg ,
42
40
input_zp : int ,
43
41
output_zp : int ,
44
- accumulator_type : ts . DType ,
42
+ accumulator_type : Any ,
45
43
) -> None :
46
44
45
+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
46
+
47
47
input_tensor = inputs [0 ]
48
48
kernel_size_list = inputs [1 ].special
49
49
stride_size_list = inputs [2 ].special
@@ -79,10 +79,12 @@ def _build_generic_avgpool2d(
79
79
def define_node (
80
80
self ,
81
81
node : torch .fx .Node ,
82
- tosa_graph : ts . TosaSerializer ,
82
+ tosa_graph : Any ,
83
83
inputs : List [TosaArg ],
84
84
output : TosaArg ,
85
85
) -> None :
86
+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
87
+
86
88
input_tensor = inputs [0 ]
87
89
assert input_tensor .dtype == ts .DType .INT8
88
90
@@ -110,10 +112,135 @@ class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI):
110
112
def define_node (
111
113
self ,
112
114
node : torch .fx .Node ,
113
- tosa_graph : ts . TosaSerializer ,
115
+ tosa_graph : Any ,
114
116
inputs : List [TosaArg ],
115
117
output : TosaArg ,
116
118
) -> None :
119
+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
120
+
121
+ assert (
122
+ inputs [0 ].dtype == ts .DType .INT8 or inputs [0 ].dtype == ts .DType .FP32
123
+ ), "Only FP32 and INT8 supported"
124
+
125
+ if inputs [0 ].dtype == ts .DType .INT8 :
126
+ super ().define_node (node , tosa_graph , inputs , output )
127
+
128
+ if inputs [0 ].dtype == ts .DType .FP32 :
129
+ accumulator_type = ts .DType .FP32
130
+ # Initilize zero point to zero.
131
+ input_zp = 0
132
+ output_zp = 0
133
+
134
+ self ._build_generic_avgpool2d (
135
+ node , tosa_graph , inputs , output , input_zp , output_zp , accumulator_type
136
+ )
137
+
138
+
139
+ @register_node_visitor
140
+ class AvgPool2dVisitor (NodeVisitor ):
141
+ target = "aten.avg_pool2d.default"
142
+
143
+ tosa_specs = [
144
+ TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
145
+ ]
146
+
147
+ def __init__ (self , * args ):
148
+ super ().__init__ (* args )
149
+
150
+ def _build_generic_avgpool2d (
151
+ self ,
152
+ node : torch .fx .Node ,
153
+ tosa_graph : Any ,
154
+ inputs : List [TosaArg ],
155
+ output : TosaArg ,
156
+ input_zp : int ,
157
+ output_zp : int ,
158
+ accumulator_type : Any ,
159
+ ) -> None :
160
+
161
+ import serializer .tosa_serializer as ts # type: ignore
162
+
163
+ input_tensor = inputs [0 ]
164
+ kernel_size_list = inputs [1 ].special
165
+ stride_size_list = inputs [2 ].special
166
+
167
+ try :
168
+ pad_size_list = inputs [3 ].special
169
+ pad_size_list = [
170
+ pad_size_list [0 ],
171
+ pad_size_list [0 ],
172
+ pad_size_list [1 ],
173
+ pad_size_list [1 ],
174
+ ]
175
+ except IndexError :
176
+ pad_size_list = [0 , 0 , 0 , 0 ]
177
+
178
+ attr = ts .TosaSerializerAttribute ()
179
+ attr .AvgPool2dAttribute (
180
+ kernel = kernel_size_list ,
181
+ stride = stride_size_list ,
182
+ pad = pad_size_list ,
183
+ acc_type = accumulator_type ,
184
+ )
185
+ input_zp_tensor = tosa_graph .addConst (
186
+ shape = [1 ], dtype = output .dtype , vals = [input_zp ]
187
+ )
188
+ output_zp_tensor = tosa_graph .addConst (
189
+ shape = [1 ], dtype = output .dtype , vals = [output_zp ]
190
+ )
191
+
192
+ tosa_graph .addOperator (
193
+ ts .TosaOp .Op ().AVG_POOL2D ,
194
+ [input_tensor .name , input_zp_tensor .name , output_zp_tensor .name ],
195
+ [output .name ],
196
+ attr ,
197
+ )
198
+
199
+ def define_node (
200
+ self ,
201
+ node : torch .fx .Node ,
202
+ tosa_graph : Any ,
203
+ inputs : List [TosaArg ],
204
+ output : TosaArg ,
205
+ ) -> None :
206
+ import serializer .tosa_serializer as ts # type: ignore
207
+
208
+ input_tensor = inputs [0 ]
209
+ assert input_tensor .dtype == ts .DType .INT8
210
+
211
+ accumulator_type = ts .DType .INT32
212
+
213
+ input_qargs = get_input_qparams (node )
214
+ input_zp = input_qargs [0 ].zp
215
+
216
+ output_qargs = get_output_qparams (node )
217
+ output_zp = output_qargs [0 ].zp
218
+
219
+ self ._build_generic_avgpool2d (
220
+ node , tosa_graph , inputs , output , input_zp , output_zp , accumulator_type
221
+ )
222
+
223
+
224
+ @register_node_visitor
225
+ class AvgPool2dVisitor_FP (AvgPool2dVisitor ):
226
+ target = "aten.avg_pool2d.default"
227
+
228
+ tosa_specs = [
229
+ TosaSpecification .create_from_string ("TOSA-1.0+FP" ),
230
+ ]
231
+
232
+ def __init__ (self , * args ):
233
+ super ().__init__ (* args )
234
+
235
+ def define_node (
236
+ self ,
237
+ node : torch .fx .Node ,
238
+ tosa_graph : Any ,
239
+ inputs : List [TosaArg ],
240
+ output : TosaArg ,
241
+ ) -> None :
242
+ import serializer .tosa_serializer as ts # type: ignore
243
+
117
244
assert (
118
245
inputs [0 ].dtype == ts .DType .INT8 or inputs [0 ].dtype == ts .DType .FP32
119
246
), "Only FP32 and INT8 supported"
0 commit comments