@@ -4,14 +4,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-strict
+
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Any, Callable, List, Optional, Tuple, Type, Union
+from typing import Callable, List, Optional, Tuple, Type, Union

 import torch
 from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams

 from torch import fx
+from torch._ops import OpOverload
 from torch.ao.quantization.quantizer import (
     DerivedQuantizationSpec,
     SharedQuantizationSpec,
@@ -44,18 +47,22 @@ class PartitionAnchors:

 class QuantizationPattern(ABC):
     @abstractmethod
-    def partition_types(self):
+    def partition_types(
+        self,
+    ) -> Union[List[Type[torch.nn.Module]], List[Callable[..., torch.Tensor]]]:
         """
         List of types to be passed to find_sequential_partitions.
         """
         pass

     @abstractmethod
-    def get_anchors(self, gm, fused_partition) -> Optional[PartitionAnchors]:
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Optional[PartitionAnchors]:
         pass

     @abstractmethod
-    def replacement_op(self) -> Callable[..., Any]:
+    def replacement_op(self) -> OpOverload:
         """
         Operator (most likely a custom one) that this partition should be fused into in
         the backend. Refer to the QuantFusion pass for examples.
@@ -91,10 +98,30 @@ def get_anchors(
             output=[(addmm_node,)],
         )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear


+class BmmPattern(QuantizationPattern):
+    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
+        return [torch.bmm]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        bmm_node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(bmm_node, 0), (bmm_node, 1)],
+            weights=[],
+            biases=[],
+            output=[(bmm_node,)],
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_matmul.default
+
+
 class Conv1dPattern(QuantizationPattern):
     def partition_types(self) -> List[Type[torch.nn.Module]]:
         return [torch.nn.Conv1d]
@@ -129,7 +156,7 @@ def get_anchors(
             output=[(conv1d_node,)],
         )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_conv.default

@@ -167,15 +194,17 @@ def get_anchors(
             output=[(conv2d_node,)],
         )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_conv.default


 class LayerNormPattern(QuantizationPattern):
-    def partition_types(self):
+    def partition_types(self) -> List[Type[torch.nn.Module]]:
         return [torch.nn.LayerNorm]

-    def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
         layer_norm_node = fused_partition[0].nodes[-1]

         # Weights and biases are used as fp32 by our kernel, so they are
@@ -189,15 +218,17 @@ def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
             output=[(layer_norm_node,)],
         )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_layer_norm.default


 class LayerNormFunctionalPattern(QuantizationPattern):
-    def partition_types(self):
+    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
         return [torch.nn.functional.layer_norm]

-    def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
         layer_norm_node = fused_partition[0].nodes[-1]

         others = [(layer_norm_node, 1)]
@@ -221,7 +252,7 @@ def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
             output=[(layer_norm_node,)],
         )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_layer_norm.default

@@ -259,12 +290,12 @@ def get_anchors(
             output=[(linear_node,)],
         )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear.default


 class LinearFunctionalPattern(QuantizationPattern):
-    def partition_types(self):
+    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
         return [torch.nn.functional.linear]

     def get_anchors(
@@ -297,12 +328,12 @@ def get_anchors(
             output=[(linear_node,)],
         )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear.default


 class MatmulPattern(QuantizationPattern):
-    def partition_types(self):
+    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
         return [torch.matmul]

     def get_anchors(
@@ -317,7 +348,7 @@ def get_anchors(
             output=[(matmul_node,)],
         )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_matmul.default

@@ -339,5 +370,5 @@ def get_anchors(
             ],
         )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_relu.default
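
Usage note: the sketch below shows how a pattern such as the new BmmPattern plugs into this interface. It is a minimal illustration under stated assumptions, not code from this diff: the helper annotate_with_pattern is hypothetical, and the exact find_sequential_partitions import path and call details may differ from what the Cadence quantizer and the QuantFusion pass actually do.

# Hypothetical driver, for illustration only (not part of this diff).
from torch import fx
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions


def annotate_with_pattern(gm: fx.GraphModule, pattern: QuantizationPattern) -> None:
    # partition_types() supplies the module/functional types the partition
    # search matches, as the QuantizationPattern docstring states.
    fused_partitions = find_sequential_partitions(gm, pattern.partition_types())
    for fused_partition in fused_partitions:
        anchors = pattern.get_anchors(gm, fused_partition)
        if anchors is None:
            continue  # the pattern declined this particular match
        # anchors.inputs / weights / biases / output hold (node, arg-index)
        # pairs marking where quantization specs attach; the QuantFusion pass
        # later swaps the matched subgraph for pattern.replacement_op(),
        # e.g. torch.ops.cadence.quantized_matmul.default for BmmPattern.
        ...


# e.g. annotate_with_pattern(exported_gm, BmmPattern())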