
Commit c04f70b

XuehaiPan authored and pytorchmergebot committed
[BE] enable UFMT for torch/ao/ (pytorch#128864)
Part of pytorch#123062

Pull Request resolved: pytorch#128864
Approved by: https://github.com/ezyang
1 parent 434f60c commit c04f70b

13 files changed: 1283 additions, 872 deletions
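For context: PyTorch's UFMT lint chains usort (import sorting) with Black (code formatting), so enabling it produces exactly the two kinds of mechanical change seen in the hunks below: reordered imports and reflowed long lines. A minimal before/after sketch of that transformation, reusing the _find_match signature from this diff (illustrative, not part of the commit):

    # Before: import names unsorted, two parameters on one line.
    #
    #     from typing import Dict, List, Optional, Any, Union, Callable, Set
    #
    #     def _find_match(
    #         str_list: Union[Dict[str, Any], List[str]], key_str: str,
    #         postfix: str,
    #     ) -> Optional[str]: ...
    #
    # After: the imported names are alphabetized, and Black puts one
    # parameter per line with a trailing comma once the list is split.
    from typing import Any, Callable, Dict, List, Optional, Set, Union


    def _find_match(
        str_list: Union[Dict[str, Any], List[str]],
        key_str: str,
        postfix: str,
    ) -> Optional[str]:
        ...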

.lintrunner.toml

Lines changed: 0 additions & 14 deletions
@@ -1191,20 +1191,6 @@ exclude_patterns = [
     'torch/_export/trace.py',
     'torch/_export/verifier.py',
     'torch/_vendor/**',
-    'torch/ao/__init__.py',
-    'torch/ao/ns/__init__.py',
-    'torch/ao/ns/_numeric_suite.py',
-    'torch/ao/ns/_numeric_suite_fx.py',
-    'torch/ao/ns/fx/__init__.py',
-    'torch/ao/ns/fx/graph_matcher.py',
-    'torch/ao/ns/fx/graph_passes.py',
-    'torch/ao/ns/fx/mappings.py',
-    'torch/ao/ns/fx/n_shadows_utils.py',
-    'torch/ao/ns/fx/ns_types.py',
-    'torch/ao/ns/fx/pattern_utils.py',
-    'torch/ao/ns/fx/qconfig_multi_mapping.py',
-    'torch/ao/ns/fx/utils.py',
-    'torch/ao/ns/fx/weight_utils.py',
     'torch/compiler/__init__.py',
     'torch/contrib/__init__.py',
     'torch/contrib/_tensorboard_vis.py',
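Deleting these paths from exclude_patterns is the whole mechanism of the change: a linter configured in .lintrunner.toml applies to files that match its include globs and none of its exclude globs, so removing the torch/ao/ns entries opts them into UFMT. A rough Python sketch of that include/exclude logic (illustrative only, not lintrunner's actual implementation; the globs here are shortened stand-ins):

    from fnmatch import fnmatch

    include_patterns = ["**/*.py"]           # hypothetical include glob
    exclude_patterns = ["torch/_vendor/**"]  # 'torch/ao/...' entries now removed

    def is_formatted(path: str) -> bool:
        # A file is in scope only if some include matches and no exclude does.
        included = any(fnmatch(path, p) for p in include_patterns)
        excluded = any(fnmatch(path, p) for p in exclude_patterns)
        return included and not excluded

    print(is_formatted("torch/ao/ns/fx/utils.py"))    # True after this commit
    print(is_formatted("torch/_vendor/vendored.py"))  # False: still excluded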

torch/ao/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,8 +10,10 @@
     "pruning",
 ]
 
+
 def __getattr__(name):
     if name in __all__:
         import importlib
+
         return importlib.import_module("." + name, __name__)
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
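Beyond the blank-line fixes, this hunk touches the module-level __getattr__ hook (PEP 562) that torch.ao uses to import its submodules lazily on first attribute access. A self-contained sketch of the same pattern (the pkg/heavy names are hypothetical):

    # pkg/__init__.py
    __all__ = ["heavy"]  # submodules to resolve on demand

    def __getattr__(name):
        # Invoked only when normal attribute lookup on the package fails,
        # i.e. the first time `pkg.heavy` is touched. Importing a submodule
        # binds it as an attribute of the parent package, so this hook is
        # not called again for the same name.
        if name in __all__:
            import importlib

            return importlib.import_module("." + name, __name__)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

The payoff is that `import pkg` stays cheap; the heavy submodule is only paid for by callers that actually use it.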

torch/ao/ns/_numeric_suite.py

Lines changed: 63 additions & 26 deletions
@@ -1,15 +1,16 @@
 # mypy: allow-untyped-defs
+from typing import Any, Callable, Dict, List, Optional, Set, Union
+
 import torch
-import torch.nn as nn
 import torch.ao.nn.quantized as nnq
 import torch.ao.nn.quantized.dynamic as nnqd
+import torch.nn as nn
 from torch.ao.quantization import prepare
-from typing import Dict, List, Optional, Any, Union, Callable, Set
-
 from torch.ao.quantization.quantization_mappings import (
     get_default_compare_output_module_list,
 )
 
+
 NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
     nnqd.Linear,
     nnq.Linear,
@@ -19,7 +20,8 @@
 
 
 def _find_match(
-    str_list: Union[Dict[str, Any], List[str]], key_str: str,
+    str_list: Union[Dict[str, Any], List[str]],
+    key_str: str,
     postfix: str,
 ) -> Optional[str]:
     split_str = key_str.split(".")
@@ -120,7 +122,8 @@ def compare_weights(
 
 
 def _get_logger_dict_helper(
-    mod: nn.Module, target_dict: Dict[str, Any],
+    mod: nn.Module,
+    target_dict: Dict[str, Any],
     prefix: str = "",
 ) -> None:
     r"""This is the helper function for get_logger_dict
@@ -168,8 +171,7 @@ def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
 
 
 class Logger(nn.Module):
-    r"""Base class for stats logging
-    """
+    r"""Base class for stats logging"""
 
     def __init__(self):
         super().__init__()
@@ -180,8 +182,10 @@ def __init__(self):
         self.dtype = torch.quint8
 
     def forward(self, x):
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         pass
 
 
@@ -196,8 +200,10 @@ def __init__(self):
         self.stats["quantized"] = []
 
     def forward(self, x, y):
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         if len(x) > 1:
             x = x[0]
         if len(y) > 1:
@@ -207,17 +213,17 @@ def forward(self, x, y):
 
 
 class OutputLogger(Logger):
-    r"""Class used to log the outputs of the module
-    """
+    r"""Class used to log the outputs of the module"""
 
     def __init__(self):
         super().__init__()
         self.stats["tensor_val"] = []
 
-
     def forward(self, x):
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         self.stats["tensor_val"].append(x)
         return x
 
@@ -256,8 +262,10 @@ def __init__(self, q_module, float_module, logger_cls):
         self.logger = logger_cls()
 
     def forward(self, *x) -> torch.Tensor:
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         xl = _convert_tuple_to_list(x)
         output = self.orig_module(*xl)
         xl_float = _dequantize_tensor_list(xl)
@@ -266,8 +274,10 @@ def forward(self, *x) -> torch.Tensor:
         return output
 
     def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         output = self.orig_module.add(x, y)
         x = x.dequantize()
         y = y.dequantize()
@@ -276,17 +286,21 @@ def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return output
 
     def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         output = self.orig_module.add_scalar(x, y)
         x = x.dequantize()
         shadow_output = self.shadow_module.add_scalar(x, y)
         self.logger(output, shadow_output)
         return output
 
     def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         output = self.orig_module.mul(x, y)
         x = x.dequantize()
         y = y.dequantize()
@@ -295,26 +309,32 @@ def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return output
 
     def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         output = self.orig_module.mul_scalar(x, y)
         x = x.dequantize()
         shadow_output = self.shadow_module.mul_scalar(x, y)
         self.logger(output, shadow_output)
         return output
 
     def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         output = self.orig_module.cat(x, dim)
         x = [y.dequantize() for y in x]
         shadow_output = self.shadow_module.cat(x, dim)
         self.logger(output, shadow_output)
         return output
 
     def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        # fmt: off
         """
         """  # blank docblock to make autodoc happy
+        # fmt: on
         output = self.orig_module.add_relu(x, y)
         x = x.dequantize()
         y = y.dequantize()
@@ -324,8 +344,10 @@ def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
 
 def prepare_model_with_stubs(
-    float_module: nn.Module, q_module: nn.Module,
-    module_swap_list: Set[type], logger_cls: Callable,
+    float_module: nn.Module,
+    q_module: nn.Module,
+    module_swap_list: Set[type],
+    logger_cls: Callable,
 ) -> None:
     r"""Prepare the model by attaching the float module to its matching quantized
     module as the shadow if the float module type is in module_swap_list.
@@ -343,15 +365,16 @@ def prepare_model_with_stubs(
         logger_cls: type of logger to be used in shadow module to process the outputs of
             quantized module and its float shadow module
     """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_with_stubs")
+    torch._C._log_api_usage_once(
+        "quantization_api._numeric_suite.prepare_model_with_stubs"
+    )
 
     float_module_children = {}
     for name, mod in float_module.named_children():
         float_module_children[name] = mod
 
     reassign = {}
     for name, mod in q_module.named_children():
-
         if name not in float_module_children:
             continue
 
@@ -362,23 +385,28 @@ def prepare_model_with_stubs(
 
         # Insert shadow module only if the module is not of the same type as
        # the floating point module
-        if type(float_mod) in module_swap_list and not _is_identical_module_type(mod, float_mod):
+        if type(float_mod) in module_swap_list and not _is_identical_module_type(
+            mod, float_mod
+        ):
             reassign[name] = Shadow(mod, float_mod, logger_cls)
 
     for key, value in reassign.items():
         q_module._modules[key] = value
 
+
 def _is_identical_module_type(mod1, mod2):
     # Compare if two modules have the same dtype
     mod1_module_types = [type(mod) for mod in mod1.modules()]
     mod2_module_types = [type(mod) for mod in mod2.modules()]
     return mod1_module_types == mod2_module_types
 
 
-
 def compare_model_stub(
-    float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type],
-    *data, logger_cls=ShadowLogger
+    float_model: nn.Module,
+    q_model: nn.Module,
+    module_swap_list: Set[type],
+    *data,
+    logger_cls=ShadowLogger,
 ) -> Dict[str, Dict]:
     r"""Compare quantized module in a model with its floating point counterpart,
     feeding both of them the same input. Return a dict with key corresponding to
@@ -419,7 +447,8 @@ def compare_model_stub(
 
 
 def get_matching_activations(
-    float_module: nn.Module, q_module: nn.Module,
+    float_module: nn.Module,
+    q_module: nn.Module,
 ) -> Dict[str, Dict[str, torch.Tensor]]:
     r"""Find the matching activation between float and quantized modules.
 
@@ -432,7 +461,9 @@ def get_matching_activations(
         entry being a dictionary with two keys 'float' and 'quantized', containing
         the matching float and quantized activations
     """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_matching_activations")
+    torch._C._log_api_usage_once(
+        "quantization_api._numeric_suite.get_matching_activations"
+    )
     float_dict = get_logger_dict(float_module)
     quantized_dict = get_logger_dict(q_module)
     act_dict: Dict[str, Dict] = {}
@@ -451,7 +482,7 @@ def prepare_model_outputs(
     float_module: nn.Module,
     q_module: nn.Module,
     logger_cls=OutputLogger,
-    allow_list=None
+    allow_list=None,
 ) -> None:
     r"""Prepare the model by attaching the logger to both float module
     and quantized module if they are in the allow_list.
@@ -462,20 +493,24 @@ def prepare_model_outputs(
         logger_cls: type of logger to be attached to float_module and q_module
         allow_list: list of module types to attach logger
     """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs")
+    torch._C._log_api_usage_once(
+        "quantization_api._numeric_suite.prepare_model_outputs"
+    )
     if allow_list is None:
         allow_list = get_default_compare_output_module_list()
 
     qconfig_debug = torch.ao.quantization.QConfig(activation=logger_cls, weight=None)
     float_module.qconfig = qconfig_debug  # type: ignore[assignment]
-    prepare(float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={})
+    prepare(
+        float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={}
+    )
     q_module.qconfig = qconfig_debug  # type: ignore[assignment]
     prepare(
         q_module,
         inplace=True,
         allow_list=allow_list,
         observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
-        prepare_custom_config_dict={}
+        prepare_custom_config_dict={},
     )
 
 
@@ -484,7 +519,7 @@ def compare_model_outputs(
     q_model: nn.Module,
     *data,
     logger_cls=OutputLogger,
-    allow_list=None
+    allow_list=None,
 ) -> Dict[str, Dict[str, torch.Tensor]]:
     r"""Compare output activations between float and quantized models at
     corresponding locations for the same input. Return a dict with key corresponding
@@ -517,7 +552,9 @@ def compare_model_outputs(
     and each entry being a dictionary with two keys 'float' and 'quantized',
     containing the matching float and quantized activations
     """
-    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_outputs")
+    torch._C._log_api_usage_once(
+        "quantization_api._numeric_suite.compare_model_outputs"
+    )
     if allow_list is None:
         allow_list = get_default_compare_output_module_list()
     prepare_model_outputs(float_model, q_model, logger_cls, allow_list)
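The most repeated edit in this file is the new # fmt: off / # fmt: on guards. These methods carry deliberately blank two-line docstrings (kept so that autodoc renders an entry for them), and Black would otherwise normalize such a docstring, so the guards mark the region as verbatim for the formatter. The pattern in isolation, as a minimal sketch:

    class Logger:
        def forward(self, x):
            # fmt: off
            """
            """  # blank docblock to make autodoc happy
            # fmt: on
            # Everything between the guards is left exactly as written;
            # without them the formatter would rewrite the empty docstring.
            pass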
