@@ -8,7 +8,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
-from typing import Callable, List, Optional, Tuple, Type, Union
+from typing import List, Optional, Tuple, Union

import torch
from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams
@@ -47,17 +47,15 @@ class PartitionAnchors:

class QuantizationPattern(ABC):
    @abstractmethod
-    def partition_types(
-        self,
-    ) -> Union[List[Type[torch.nn.Module]], List[Callable[..., torch.Tensor]]]:
+    def partition_types(self) -> list[OpOverload]:
        """
-        List of types to be passed to find_sequential_partitions.
+        List of types to be passed to find_sequential_partitions_aten.
        """
        pass

    @abstractmethod
    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+        self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Optional[PartitionAnchors]:
        pass
@@ -71,8 +69,8 @@ def replacement_op(self) -> OpOverload:


class AddmmPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.addmm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.addmm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -103,8 +101,8 @@ def replacement_op(self) -> OpOverload:


class BmmPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.bmm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.bmm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -123,8 +121,8 @@ def replacement_op(self) -> OpOverload:


class Conv1dPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Conv1d]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv1d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -161,8 +159,8 @@ def replacement_op(self) -> OpOverload:


class Conv2dPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Conv2d]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv2d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -199,32 +197,8 @@ def replacement_op(self) -> OpOverload:


class LayerNormPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.LayerNorm]
-
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
-        layer_norm_node = fused_partition[0].nodes[-1]
-
-        # Weights and biases are used as fp32 by our kernel, so they are
-        # passed in as others here along with the normalized shape.
-        return PartitionAnchors(
-            inputs=[(layer_norm_node, 0)],
-            weights=[],
-            biases=[],
-            # Ordering: normalized_shape, weights, bias
-            others=[(layer_norm_node, 1), (layer_norm_node, 2), (layer_norm_node, 3)],
-            output=[(layer_norm_node,)],
-        )
-
-    def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_layer_norm.default
-
-
-class LayerNormFunctionalPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.nn.functional.layer_norm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.layer_norm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -257,8 +231,8 @@ def replacement_op(self) -> OpOverload:


class LinearPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Linear]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.linear.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -294,47 +268,9 @@ def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_linear.default


-class LinearFunctionalPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.nn.functional.linear]
-
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
-        linear_node = fused_partition[0].nodes[-1]
-
-        bias_qspec = DerivedQuantizationSpec(
-            derived_from=[
-                (linear_node.args[0], linear_node),
-                (linear_node.args[1], linear_node),
-            ],
-            derive_qparams_fn=get_bias_qparams,
-            dtype=torch.int32,
-            quant_min=-(2**31),
-            quant_max=2**31 - 1,
-            qscheme=torch.per_tensor_affine,
-        )
-
-        # Keep bias empty if not supplied
-        bias = []
-        if len(linear_node.args) > 2 and linear_node.args[2] is not None:
-            bias = [(linear_node, 2, bias_qspec)]
-
-        return PartitionAnchors(
-            inputs=[(linear_node, 0)],
-            weights=[(linear_node, 1)],
-            # pyre-fixme[6]: Incompatible parameter type
-            biases=bias,
-            output=[(linear_node,)],
-        )
-
-    def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_linear.default
-
-
class MatmulPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.matmul]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.matmul.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -353,8 +289,8 @@ def replacement_op(self) -> OpOverload:


class ReluPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.ReLU]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.relu.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
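Note on the migration: `partition_types` now returns ATen `OpOverload` objects instead of `nn.Module` subclasses or Python callables. Because both the module form (`torch.nn.ReLU`) and the functional form (`torch.nn.functional.relu`) lower to the same ATen operator in an exported graph, one pattern per op suffices, which is why the separate `*FunctionalPattern` classes are deleted. A minimal sketch of the idea, outside this diff (the standalone `partition_types` function below is illustrative only, not code from this PR):

    # Illustrative sketch, not part of the diff above.
    from typing import List

    import torch
    from torch._ops import OpOverload  # the type referenced by the new signatures


    def partition_types() -> List[OpOverload]:
        # A single OpOverload covers both nn.ReLU and F.relu call sites,
        # since both lower to the same ATen operator when exported.
        return [torch.ops.aten.relu.default]


    assert isinstance(partition_types()[0], OpOverload)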