@@ -1,15 +1,16 @@
 # mypy: allow-untyped-defs
+from typing import Any, Callable, Dict, List, Optional, Set, Union
+
 import torch
-import torch.nn as nn
 import torch.ao.nn.quantized as nnq
 import torch.ao.nn.quantized.dynamic as nnqd
+import torch.nn as nn
 from torch.ao.quantization import prepare
-from typing import Dict, List, Optional, Any, Union, Callable, Set
-
 from torch.ao.quantization.quantization_mappings import (
     get_default_compare_output_module_list,
 )

+
 NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
     nnqd.Linear,
     nnq.Linear,
@@ -19,7 +20,8 @@


 def _find_match(
-    str_list: Union[Dict[str, Any], List[str]], key_str: str,
+    str_list: Union[Dict[str, Any], List[str]],
+    key_str: str,
     postfix: str,
 ) -> Optional[str]:
     split_str = key_str.split(".")
@@ -120,7 +122,8 @@ def compare_weights(


 def _get_logger_dict_helper(
-    mod: nn.Module, target_dict: Dict[str, Any],
+    mod: nn.Module,
+    target_dict: Dict[str, Any],
     prefix: str = "",
 ) -> None:
     r"""This is the helper function for get_logger_dict
@@ -168,8 +171,7 @@ def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:


 class Logger(nn.Module):
-    r"""Base class for stats logging
-    """
+    r"""Base class for stats logging"""

     def __init__(self):
         super().__init__()
@@ -180,8 +182,10 @@ def __init__(self):
         self.dtype = torch.quint8

     def forward(self, x):
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         pass


@@ -196,8 +200,10 @@ def __init__(self):
         self.stats["quantized"] = []

     def forward(self, x, y):
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         if len(x) > 1:
             x = x[0]
         if len(y) > 1:
@@ -207,17 +213,17 @@ def forward(self, x, y):


 class OutputLogger(Logger):
-    r"""Class used to log the outputs of the module
-    """
+    r"""Class used to log the outputs of the module"""

     def __init__(self):
         super().__init__()
         self.stats["tensor_val"] = []

-
     def forward(self, x):
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         self.stats["tensor_val"].append(x)
         return x

@@ -256,8 +262,10 @@ def __init__(self, q_module, float_module, logger_cls):
         self.logger = logger_cls()

     def forward(self, *x) -> torch.Tensor:
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         xl = _convert_tuple_to_list(x)
         output = self.orig_module(*xl)
         xl_float = _dequantize_tensor_list(xl)
@@ -266,8 +274,10 @@ def forward(self, *x) -> torch.Tensor:
         return output

     def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         output = self.orig_module.add(x, y)
         x = x.dequantize()
         y = y.dequantize()
@@ -276,17 +286,21 @@ def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return output

     def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         output = self.orig_module.add_scalar(x, y)
         x = x.dequantize()
         shadow_output = self.shadow_module.add_scalar(x, y)
         self.logger(output, shadow_output)
         return output

     def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         output = self.orig_module.mul(x, y)
         x = x.dequantize()
         y = y.dequantize()
@@ -295,26 +309,32 @@ def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return output

     def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         output = self.orig_module.mul_scalar(x, y)
         x = x.dequantize()
         shadow_output = self.shadow_module.mul_scalar(x, y)
         self.logger(output, shadow_output)
         return output

     def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         output = self.orig_module.cat(x, dim)
         x = [y.dequantize() for y in x]
         shadow_output = self.shadow_module.cat(x, dim)
         self.logger(output, shadow_output)
         return output

     def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         output = self.orig_module.add_relu(x, y)
         x = x.dequantize()
         y = y.dequantize()
@@ -324,8 +344,10 @@ def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:


 def prepare_model_with_stubs(
-    float_module: nn.Module, q_module: nn.Module,
-    module_swap_list: Set[type], logger_cls: Callable,
+    float_module: nn.Module,
+    q_module: nn.Module,
+    module_swap_list: Set[type],
+    logger_cls: Callable,
 ) -> None:
     r"""Prepare the model by attaching the float module to its matching quantized
     module as the shadow if the float module type is in module_swap_list.
@@ -343,15 +365,16 @@ def prepare_model_with_stubs(
         logger_cls: type of logger to be used in shadow module to process the outputs of
             quantized module and its float shadow module
     """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_with_stubs")
+    torch._C._log_api_usage_once(
+        "quantization_api._numeric_suite.prepare_model_with_stubs"
+    )

     float_module_children = {}
     for name, mod in float_module.named_children():
         float_module_children[name] = mod

     reassign = {}
     for name, mod in q_module.named_children():
-
         if name not in float_module_children:
             continue

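For readers skimming this formatting-only diff, a minimal usage sketch of `prepare_model_with_stubs` may help. The toy model, the dynamic-quantization step, and the `{nn.Linear}` swap list below are illustrative assumptions rather than part of this PR; only `prepare_model_with_stubs` and `ShadowLogger` come from this file.

```python
# Minimal sketch (illustrative model and data, not from this PR).
import torch
import torch.nn as nn
import torch.ao.ns._numeric_suite as ns
from torch.ao.quantization import quantize_dynamic

float_model = nn.Sequential(nn.Linear(4, 4)).eval()  # hypothetical float model
q_model = quantize_dynamic(float_model)              # its dynamically quantized copy

# Each quantized child whose float counterpart's type is in the swap list is wrapped
# in Shadow(quantized, float, ShadowLogger); q_model is modified in place.
ns.prepare_model_with_stubs(float_model, q_model, {nn.Linear}, ns.ShadowLogger)

# From now on, every forward pass logs (quantized output, float shadow output) pairs.
q_model(torch.randn(2, 4))
```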
@@ -362,23 +385,28 @@ def prepare_model_with_stubs(

         # Insert shadow module only if the module is not of the same type as
         # the floating point module
-        if type(float_mod) in module_swap_list and not _is_identical_module_type(mod, float_mod):
+        if type(float_mod) in module_swap_list and not _is_identical_module_type(
+            mod, float_mod
+        ):
             reassign[name] = Shadow(mod, float_mod, logger_cls)

     for key, value in reassign.items():
         q_module._modules[key] = value

+
 def _is_identical_module_type(mod1, mod2):
     # Compare if two modules have the same dtype
     mod1_module_types = [type(mod) for mod in mod1.modules()]
     mod2_module_types = [type(mod) for mod in mod2.modules()]
     return mod1_module_types == mod2_module_types


-
 def compare_model_stub(
-    float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type],
-    *data, logger_cls=ShadowLogger
+    float_model: nn.Module,
+    q_model: nn.Module,
+    module_swap_list: Set[type],
+    *data,
+    logger_cls=ShadowLogger,
 ) -> Dict[str, Dict]:
     r"""Compare quantized module in a model with its floating point counterpart,
     feeding both of them the same input. Return a dict with key corresponding to
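`compare_model_stub` performs the stub insertion and the forward pass in one call; below is a short sketch of how it is typically invoked and what it returns. The toy model and data are assumptions for illustration.

```python
# Sketch of calling compare_model_stub (toy model and data are assumptions).
import torch
import torch.nn as nn
import torch.ao.ns._numeric_suite as ns
from torch.ao.quantization import quantize_dynamic

float_model = nn.Sequential(nn.Linear(8, 8)).eval()
q_model = quantize_dynamic(float_model)
data = torch.randn(4, 8)

ob_dict = ns.compare_model_stub(float_model, q_model, {nn.Linear}, data)
for name, stats in ob_dict.items():
    # Each entry holds the logged quantized outputs and their float shadow outputs.
    print(name, len(stats["quantized"]), len(stats["float"]))
```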
@@ -419,7 +447,8 @@ def compare_model_stub(


 def get_matching_activations(
-    float_module: nn.Module, q_module: nn.Module,
+    float_module: nn.Module,
+    q_module: nn.Module,
 ) -> Dict[str, Dict[str, torch.Tensor]]:
     r"""Find the matching activation between float and quantized modules.

@@ -432,7 +461,9 @@ def get_matching_activations(
             entry being a dictionary with two keys 'float' and 'quantized', containing
             the matching float and quantized activations
     """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_matching_activations")
+    torch._C._log_api_usage_once(
+        "quantization_api._numeric_suite.get_matching_activations"
+    )
    float_dict = get_logger_dict(float_module)
    quantized_dict = get_logger_dict(q_module)
    act_dict: Dict[str, Dict] = {}
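`get_matching_activations` is normally paired with `prepare_model_outputs` (reformatted further down in this diff): attach the loggers, run the same data through both models, then collect the matched activations. A sketch under the same illustrative setup as above:

```python
# Sketch of the manual logger flow (toy model and data are assumptions).
import torch
import torch.nn as nn
import torch.ao.ns._numeric_suite as ns
from torch.ao.quantization import quantize_dynamic

float_model = nn.Sequential(nn.Linear(8, 8)).eval()
q_model = quantize_dynamic(float_model)
data = torch.randn(4, 8)

ns.prepare_model_outputs(float_model, q_model)  # attach OutputLogger observers
float_model(data)
q_model(data)

act_dict = ns.get_matching_activations(float_model, q_model)
for name, acts in act_dict.items():
    # 'float' and 'quantized' hold the activations logged at matching locations.
    print(name, acts["float"][0].shape, acts["quantized"][0].shape)
```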
@@ -451,7 +482,7 @@ def prepare_model_outputs(
     float_module: nn.Module,
     q_module: nn.Module,
     logger_cls=OutputLogger,
-    allow_list=None
+    allow_list=None,
 ) -> None:
     r"""Prepare the model by attaching the logger to both float module
     and quantized module if they are in the allow_list.
@@ -462,20 +493,24 @@ def prepare_model_outputs(
         logger_cls: type of logger to be attached to float_module and q_module
         allow_list: list of module types to attach logger
     """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs")
+    torch._C._log_api_usage_once(
+        "quantization_api._numeric_suite.prepare_model_outputs"
+    )
     if allow_list is None:
         allow_list = get_default_compare_output_module_list()

     qconfig_debug = torch.ao.quantization.QConfig(activation=logger_cls, weight=None)
     float_module.qconfig = qconfig_debug  # type: ignore[assignment]
-    prepare(float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={})
+    prepare(
+        float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={}
+    )
     q_module.qconfig = qconfig_debug  # type: ignore[assignment]
     prepare(
         q_module,
         inplace=True,
         allow_list=allow_list,
         observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
-        prepare_custom_config_dict={}
+        prepare_custom_config_dict={},
     )


@@ -484,7 +519,7 @@ def compare_model_outputs(
     q_model: nn.Module,
     *data,
     logger_cls=OutputLogger,
-    allow_list=None
+    allow_list=None,
 ) -> Dict[str, Dict[str, torch.Tensor]]:
     r"""Compare output activations between float and quantized models at
     corresponding locations for the same input. Return a dict with key corresponding
@@ -517,7 +552,9 @@ def compare_model_outputs(
         and each entry being a dictionary with two keys 'float' and 'quantized',
         containing the matching float and quantized activations
     """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_outputs")
+    torch._C._log_api_usage_once(
+        "quantization_api._numeric_suite.compare_model_outputs"
+    )
     if allow_list is None:
         allow_list = get_default_compare_output_module_list()
     prepare_model_outputs(float_model, q_model, logger_cls, allow_list)
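`compare_model_outputs` wraps the prepare/forward/match steps from the previous sketch into a single call. A short sketch follows; the `sqnr` helper is an illustrative metric, not part of this file, and the model and data are again assumptions.

```python
# Sketch: one-call comparison plus a simple error metric (all names below are illustrative).
import torch
import torch.nn as nn
import torch.ao.ns._numeric_suite as ns
from torch.ao.quantization import quantize_dynamic


def sqnr(x, y):
    # Signal-to-quantization-noise ratio in dB; higher means closer agreement.
    return (20 * torch.log10(torch.norm(x) / torch.norm(x - y))).item()


float_model = nn.Sequential(nn.Linear(8, 8)).eval()
q_model = quantize_dynamic(float_model)
data = torch.randn(4, 8)

act_compare_dict = ns.compare_model_outputs(float_model, q_model, data)
for name, acts in act_compare_dict.items():
    print(name, sqnr(acts["float"][0], acts["quantized"][0]))
```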