11# mypy: allow-untyped-defs
2+ from typing import Any , Callable , Dict , List , Optional , Set , Union
3+
24import torch
3- import torch .nn as nn
45import torch .ao .nn .quantized as nnq
56import torch .ao .nn .quantized .dynamic as nnqd
7+ import torch .nn as nn
68from torch .ao .quantization import prepare
7- from typing import Dict , List , Optional , Any , Union , Callable , Set
8-
99from torch .ao .quantization .quantization_mappings import (
1010 get_default_compare_output_module_list ,
1111)
1212
13+
1314NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
1415 nnqd .Linear ,
1516 nnq .Linear ,
1920
2021
2122def _find_match (
22- str_list : Union [Dict [str , Any ], List [str ]], key_str : str ,
23+ str_list : Union [Dict [str , Any ], List [str ]],
24+ key_str : str ,
2325 postfix : str ,
2426) -> Optional [str ]:
2527 split_str = key_str .split ("." )
@@ -120,7 +122,8 @@ def compare_weights(
120122
121123
122124def _get_logger_dict_helper (
123- mod : nn .Module , target_dict : Dict [str , Any ],
125+ mod : nn .Module ,
126+ target_dict : Dict [str , Any ],
124127 prefix : str = "" ,
125128) -> None :
126129 r"""This is the helper function for get_logger_dict
@@ -168,8 +171,7 @@ def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
168171
169172
170173class Logger (nn .Module ):
171- r"""Base class for stats logging
172- """
174+ r"""Base class for stats logging"""
173175
174176 def __init__ (self ):
175177 super ().__init__ ()
@@ -180,8 +182,10 @@ def __init__(self):
180182 self .dtype = torch .quint8
181183
182184 def forward (self , x ):
185+ # fmt: off
183186 """
184187 """ # blank docblock to make autodoc happy
188+ # fmt: on
185189 pass
186190
187191
@@ -196,8 +200,10 @@ def __init__(self):
196200 self .stats ["quantized" ] = []
197201
198202 def forward (self , x , y ):
203+ # fmt: off
199204 """
200205 """ # blank docblock to make autodoc happy
206+ # fmt: on
201207 if len (x ) > 1 :
202208 x = x [0 ]
203209 if len (y ) > 1 :
@@ -207,17 +213,17 @@ def forward(self, x, y):
207213
208214
209215class OutputLogger (Logger ):
210- r"""Class used to log the outputs of the module
211- """
216+ r"""Class used to log the outputs of the module"""
212217
213218 def __init__ (self ):
214219 super ().__init__ ()
215220 self .stats ["tensor_val" ] = []
216221
217-
218222 def forward (self , x ):
223+ # fmt: off
219224 """
220225 """ # blank docblock to make autodoc happy
226+ # fmt: on
221227 self .stats ["tensor_val" ].append (x )
222228 return x
223229
@@ -256,8 +262,10 @@ def __init__(self, q_module, float_module, logger_cls):
256262 self .logger = logger_cls ()
257263
258264 def forward (self , * x ) -> torch .Tensor :
265+ # fmt: off
259266 """
260267 """ # blank docblock to make autodoc happy
268+ # fmt: on
261269 xl = _convert_tuple_to_list (x )
262270 output = self .orig_module (* xl )
263271 xl_float = _dequantize_tensor_list (xl )
@@ -266,8 +274,10 @@ def forward(self, *x) -> torch.Tensor:
266274 return output
267275
268276 def add (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
277+ # fmt: off
269278 """
270279 """ # blank docblock to make autodoc happy
280+ # fmt: on
271281 output = self .orig_module .add (x , y )
272282 x = x .dequantize ()
273283 y = y .dequantize ()
@@ -276,17 +286,21 @@ def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
276286 return output
277287
278288 def add_scalar (self , x : torch .Tensor , y : float ) -> torch .Tensor :
289+ # fmt: off
279290 """
280291 """ # blank docblock to make autodoc happy
292+ # fmt: on
281293 output = self .orig_module .add_scalar (x , y )
282294 x = x .dequantize ()
283295 shadow_output = self .shadow_module .add_scalar (x , y )
284296 self .logger (output , shadow_output )
285297 return output
286298
287299 def mul (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
300+ # fmt: off
288301 """
289302 """ # blank docblock to make autodoc happy
303+ # fmt: on
290304 output = self .orig_module .mul (x , y )
291305 x = x .dequantize ()
292306 y = y .dequantize ()
@@ -295,26 +309,32 @@ def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
295309 return output
296310
297311 def mul_scalar (self , x : torch .Tensor , y : float ) -> torch .Tensor :
312+ # fmt: off
298313 """
299314 """ # blank docblock to make autodoc happy
315+ # fmt: on
300316 output = self .orig_module .mul_scalar (x , y )
301317 x = x .dequantize ()
302318 shadow_output = self .shadow_module .mul_scalar (x , y )
303319 self .logger (output , shadow_output )
304320 return output
305321
306322 def cat (self , x : List [torch .Tensor ], dim : int = 0 ) -> torch .Tensor :
323+ # fmt: off
307324 """
308325 """ # blank docblock to make autodoc happy
326+ # fmt: on
309327 output = self .orig_module .cat (x , dim )
310328 x = [y .dequantize () for y in x ]
311329 shadow_output = self .shadow_module .cat (x , dim )
312330 self .logger (output , shadow_output )
313331 return output
314332
315333 def add_relu (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
334+ # fmt: off
316335 """
317336 """ # blank docblock to make autodoc happy
337+ # fmt: on
318338 output = self .orig_module .add_relu (x , y )
319339 x = x .dequantize ()
320340 y = y .dequantize ()
@@ -324,8 +344,10 @@ def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
324344
325345
326346def prepare_model_with_stubs (
327- float_module : nn .Module , q_module : nn .Module ,
328- module_swap_list : Set [type ], logger_cls : Callable ,
347+ float_module : nn .Module ,
348+ q_module : nn .Module ,
349+ module_swap_list : Set [type ],
350+ logger_cls : Callable ,
329351) -> None :
330352 r"""Prepare the model by attaching the float module to its matching quantized
331353 module as the shadow if the float module type is in module_swap_list.
@@ -343,15 +365,16 @@ def prepare_model_with_stubs(
343365 logger_cls: type of logger to be used in shadow module to process the outputs of
344366 quantized module and its float shadow module
345367 """
346- torch ._C ._log_api_usage_once ("quantization_api._numeric_suite.prepare_model_with_stubs" )
368+ torch ._C ._log_api_usage_once (
369+ "quantization_api._numeric_suite.prepare_model_with_stubs"
370+ )
347371
348372 float_module_children = {}
349373 for name , mod in float_module .named_children ():
350374 float_module_children [name ] = mod
351375
352376 reassign = {}
353377 for name , mod in q_module .named_children ():
354-
355378 if name not in float_module_children :
356379 continue
357380
@@ -362,23 +385,28 @@ def prepare_model_with_stubs(
362385
363386 # Insert shadow module only if the module is not of the same type as
364387 # the floating point module
365- if type (float_mod ) in module_swap_list and not _is_identical_module_type (mod , float_mod ):
388+ if type (float_mod ) in module_swap_list and not _is_identical_module_type (
389+ mod , float_mod
390+ ):
366391 reassign [name ] = Shadow (mod , float_mod , logger_cls )
367392
368393 for key , value in reassign .items ():
369394 q_module ._modules [key ] = value
370395
396+
371397def _is_identical_module_type (mod1 , mod2 ):
372398 # Compare if two modules have the same dtype
373399 mod1_module_types = [type (mod ) for mod in mod1 .modules ()]
374400 mod2_module_types = [type (mod ) for mod in mod2 .modules ()]
375401 return mod1_module_types == mod2_module_types
376402
377403
378-
379404def compare_model_stub (
380- float_model : nn .Module , q_model : nn .Module , module_swap_list : Set [type ],
381- * data , logger_cls = ShadowLogger
405+ float_model : nn .Module ,
406+ q_model : nn .Module ,
407+ module_swap_list : Set [type ],
408+ * data ,
409+ logger_cls = ShadowLogger ,
382410) -> Dict [str , Dict ]:
383411 r"""Compare quantized module in a model with its floating point counterpart,
384412 feeding both of them the same input. Return a dict with key corresponding to
@@ -419,7 +447,8 @@ def compare_model_stub(
419447
420448
421449def get_matching_activations (
422- float_module : nn .Module , q_module : nn .Module ,
450+ float_module : nn .Module ,
451+ q_module : nn .Module ,
423452) -> Dict [str , Dict [str , torch .Tensor ]]:
424453 r"""Find the matching activation between float and quantized modules.
425454
@@ -432,7 +461,9 @@ def get_matching_activations(
432461 entry being a dictionary with two keys 'float' and 'quantized', containing
433462 the matching float and quantized activations
434463 """
435- torch ._C ._log_api_usage_once ("quantization_api._numeric_suite.get_matching_activations" )
464+ torch ._C ._log_api_usage_once (
465+ "quantization_api._numeric_suite.get_matching_activations"
466+ )
436467 float_dict = get_logger_dict (float_module )
437468 quantized_dict = get_logger_dict (q_module )
438469 act_dict : Dict [str , Dict ] = {}
@@ -451,7 +482,7 @@ def prepare_model_outputs(
451482 float_module : nn .Module ,
452483 q_module : nn .Module ,
453484 logger_cls = OutputLogger ,
454- allow_list = None
485+ allow_list = None ,
455486) -> None :
456487 r"""Prepare the model by attaching the logger to both float module
457488 and quantized module if they are in the allow_list.
@@ -462,20 +493,24 @@ def prepare_model_outputs(
462493 logger_cls: type of logger to be attached to float_module and q_module
463494 allow_list: list of module types to attach logger
464495 """
465- torch ._C ._log_api_usage_once ("quantization_api._numeric_suite.prepare_model_outputs" )
496+ torch ._C ._log_api_usage_once (
497+ "quantization_api._numeric_suite.prepare_model_outputs"
498+ )
466499 if allow_list is None :
467500 allow_list = get_default_compare_output_module_list ()
468501
469502 qconfig_debug = torch .ao .quantization .QConfig (activation = logger_cls , weight = None )
470503 float_module .qconfig = qconfig_debug # type: ignore[assignment]
471- prepare (float_module , inplace = True , allow_list = allow_list , prepare_custom_config_dict = {})
504+ prepare (
505+ float_module , inplace = True , allow_list = allow_list , prepare_custom_config_dict = {}
506+ )
472507 q_module .qconfig = qconfig_debug # type: ignore[assignment]
473508 prepare (
474509 q_module ,
475510 inplace = True ,
476511 allow_list = allow_list ,
477512 observer_non_leaf_module_list = NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST ,
478- prepare_custom_config_dict = {}
513+ prepare_custom_config_dict = {},
479514 )
480515
481516
@@ -484,7 +519,7 @@ def compare_model_outputs(
484519 q_model : nn .Module ,
485520 * data ,
486521 logger_cls = OutputLogger ,
487- allow_list = None
522+ allow_list = None ,
488523) -> Dict [str , Dict [str , torch .Tensor ]]:
489524 r"""Compare output activations between float and quantized models at
490525 corresponding locations for the same input. Return a dict with key corresponding
@@ -517,7 +552,9 @@ def compare_model_outputs(
517552 and each entry being a dictionary with two keys 'float' and 'quantized',
518553 containing the matching float and quantized activations
519554 """
520- torch ._C ._log_api_usage_once ("quantization_api._numeric_suite.compare_model_outputs" )
555+ torch ._C ._log_api_usage_once (
556+ "quantization_api._numeric_suite.compare_model_outputs"
557+ )
521558 if allow_list is None :
522559 allow_list = get_default_compare_output_module_list ()
523560 prepare_model_outputs (float_model , q_model , logger_cls , allow_list )
0 commit comments