[BE] enable UFMT for torch/ao/ (pytorch#128864)
Part of pytorch#123062

Pull Request resolved: pytorch#128864
Approved by: https://github.com/ezyang
XuehaiPan authored and pytorchmergebot committed Jul 25, 2024
1 parent 434f60c · commit c04f70b
Showing 13 changed files with 1,283 additions and 872 deletions.
.lintrunner.toml: 14 changes (0 additions & 14 deletions)
@@ -1191,20 +1191,6 @@ exclude_patterns = [
'torch/_export/trace.py',
'torch/_export/verifier.py',
'torch/_vendor/**',
'torch/ao/__init__.py',
'torch/ao/ns/__init__.py',
'torch/ao/ns/_numeric_suite.py',
'torch/ao/ns/_numeric_suite_fx.py',
'torch/ao/ns/fx/__init__.py',
'torch/ao/ns/fx/graph_matcher.py',
'torch/ao/ns/fx/graph_passes.py',
'torch/ao/ns/fx/mappings.py',
'torch/ao/ns/fx/n_shadows_utils.py',
'torch/ao/ns/fx/ns_types.py',
'torch/ao/ns/fx/pattern_utils.py',
'torch/ao/ns/fx/qconfig_multi_mapping.py',
'torch/ao/ns/fx/utils.py',
'torch/ao/ns/fx/weight_utils.py',
'torch/compiler/__init__.py',
'torch/contrib/__init__.py',
'torch/contrib/_tensorboard_vis.py',
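UFMT is PyTorch's combined usort + black formatter, and deleting these torch/ao entries from exclude_patterns opts the files into it (UFMT fixes are usually applied through lintrunner, e.g. lintrunner -a --take UFMT, though the exact invocation is outside this commit). For orientation, a hedged sketch of the kind of rewrite UFMT performs; the code below is hypothetical, not taken from this diff:

# Before UFMT, imports were unsorted and mixed:
#     import torch.nn as nn
#     import torch
#     from typing import Any, Dict, Optional
# After UFMT: usort sorts and groups imports (stdlib first, then torch),
# and black explodes any signature that ends in a trailing comma onto one
# parameter per line (the "magic trailing comma").
from typing import Any, Dict, Optional

import torch
import torch.nn as nn


def find_match(
    str_list: Dict[str, Any],
    key_str: str,
    postfix: str,
) -> Optional[str]:
    # Placeholder body; only the formatting matters in this sketch.
    return None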
torch/ao/__init__.py: 2 changes (2 additions & 0 deletions)
@@ -10,8 +10,10 @@
"pruning",
]


def __getattr__(name):
if name in __all__:
import importlib

return importlib.import_module("." + name, __name__)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
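Beyond the formatting (this commit only adds the blank lines black requires before top-level definitions), the surrounding code is a PEP 562 module-level __getattr__: submodules named in __all__ are imported lazily, on first attribute access. A small sketch of the observable behavior, assuming the submodule has not been imported yet:

import torch.ao

# Plain attribute lookup fails, Python falls back to the module-level
# __getattr__ above, and the submodule is imported on demand.
pruning = torch.ao.pruning
print(pruning.__name__)  # torch.ao.pruning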
torch/ao/ns/_numeric_suite.py: 89 changes (63 additions & 26 deletions)
@@ -1,15 +1,16 @@
# mypy: allow-untyped-defs
from typing import Any, Callable, Dict, List, Optional, Set, Union

import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.nn as nn
from torch.ao.quantization import prepare
from typing import Dict, List, Optional, Any, Union, Callable, Set

from torch.ao.quantization.quantization_mappings import (
get_default_compare_output_module_list,
)


NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
nnqd.Linear,
nnq.Linear,
@@ -19,7 +20,8 @@


def _find_match(
str_list: Union[Dict[str, Any], List[str]], key_str: str,
str_list: Union[Dict[str, Any], List[str]],
key_str: str,
postfix: str,
) -> Optional[str]:
split_str = key_str.split(".")
@@ -120,7 +122,8 @@ def compare_weights(


def _get_logger_dict_helper(
mod: nn.Module, target_dict: Dict[str, Any],
mod: nn.Module,
target_dict: Dict[str, Any],
prefix: str = "",
) -> None:
r"""This is the helper function for get_logger_dict
@@ -168,8 +171,7 @@ def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:


class Logger(nn.Module):
r"""Base class for stats logging
"""
r"""Base class for stats logging"""

def __init__(self):
super().__init__()
@@ -180,8 +182,10 @@ def __init__(self):
self.dtype = torch.quint8

def forward(self, x):
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
pass


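The # fmt: off / # fmt: on guards added around each empty docblock keep black (which UFMT runs) from collapsing the deliberately empty two-line docstring that autodoc needs into a single line. A minimal, hypothetical illustration of the directive pair:

def noop(x):
    # fmt: off
    """
    """  # black keeps everything between the guards exactly as written
    # fmt: on
    return x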
@@ -196,8 +200,10 @@ def __init__(self):
self.stats["quantized"] = []

def forward(self, x, y):
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
if len(x) > 1:
x = x[0]
if len(y) > 1:
@@ -207,17 +213,17 @@


class OutputLogger(Logger):
r"""Class used to log the outputs of the module
"""
r"""Class used to log the outputs of the module"""

def __init__(self):
super().__init__()
self.stats["tensor_val"] = []


def forward(self, x):
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
self.stats["tensor_val"].append(x)
return x

@@ -256,8 +262,10 @@ def __init__(self, q_module, float_module, logger_cls):
self.logger = logger_cls()

def forward(self, *x) -> torch.Tensor:
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
xl = _convert_tuple_to_list(x)
output = self.orig_module(*xl)
xl_float = _dequantize_tensor_list(xl)
@@ -266,8 +274,10 @@ def forward(self, *x) -> torch.Tensor:
return output

def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
output = self.orig_module.add(x, y)
x = x.dequantize()
y = y.dequantize()
@@ -276,17 +286,21 @@ def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
return output

def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
output = self.orig_module.add_scalar(x, y)
x = x.dequantize()
shadow_output = self.shadow_module.add_scalar(x, y)
self.logger(output, shadow_output)
return output

def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
output = self.orig_module.mul(x, y)
x = x.dequantize()
y = y.dequantize()
@@ -295,26 +309,32 @@ def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return output

def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
output = self.orig_module.mul_scalar(x, y)
x = x.dequantize()
shadow_output = self.shadow_module.mul_scalar(x, y)
self.logger(output, shadow_output)
return output

def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
output = self.orig_module.cat(x, dim)
x = [y.dequantize() for y in x]
shadow_output = self.shadow_module.cat(x, dim)
self.logger(output, shadow_output)
return output

def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
output = self.orig_module.add_relu(x, y)
x = x.dequantize()
y = y.dequantize()
@@ -324,8 +344,10 @@ def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:


def prepare_model_with_stubs(
float_module: nn.Module, q_module: nn.Module,
module_swap_list: Set[type], logger_cls: Callable,
float_module: nn.Module,
q_module: nn.Module,
module_swap_list: Set[type],
logger_cls: Callable,
) -> None:
r"""Prepare the model by attaching the float module to its matching quantized
module as the shadow if the float module type is in module_swap_list.
@@ -343,15 +365,16 @@ def prepare_model_with_stubs(
logger_cls: type of logger to be used in shadow module to process the outputs of
quantized module and its float shadow module
"""
torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_with_stubs")
torch._C._log_api_usage_once(
"quantization_api._numeric_suite.prepare_model_with_stubs"
)

float_module_children = {}
for name, mod in float_module.named_children():
float_module_children[name] = mod

reassign = {}
for name, mod in q_module.named_children():

if name not in float_module_children:
continue

@@ -362,23 +385,28 @@

# Insert shadow module only if the module is not of the same type as
# the floating point module
if type(float_mod) in module_swap_list and not _is_identical_module_type(mod, float_mod):
if type(float_mod) in module_swap_list and not _is_identical_module_type(
mod, float_mod
):
reassign[name] = Shadow(mod, float_mod, logger_cls)

for key, value in reassign.items():
q_module._modules[key] = value

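As its docstring says, prepare_model_with_stubs swaps every matching quantized submodule for a Shadow that carries its float twin, so one forward pass exercises both copies and the ShadowLogger records both outputs. A hedged sketch of calling it directly on a toy model (compare_model_stub below wraps this same flow):

import torch
import torch.nn as nn
from torch.ao.ns._numeric_suite import ShadowLogger, prepare_model_with_stubs

float_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
q_model = torch.ao.quantization.quantize_dynamic(
    float_model, {nn.Linear}, dtype=torch.qint8
)
prepare_model_with_stubs(float_model, q_model, {nn.Linear}, ShadowLogger)
# The quantized Linears are now Shadow modules; running data through
# q_model feeds the same inputs to both copies and logs both outputs.
_ = q_model(torch.randn(4, 8))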

def _is_identical_module_type(mod1, mod2):
# Compare if two modules have the same dtype
mod1_module_types = [type(mod) for mod in mod1.modules()]
mod2_module_types = [type(mod) for mod in mod2.modules()]
return mod1_module_types == mod2_module_types



def compare_model_stub(
float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type],
*data, logger_cls=ShadowLogger
float_model: nn.Module,
q_model: nn.Module,
module_swap_list: Set[type],
*data,
logger_cls=ShadowLogger,
) -> Dict[str, Dict]:
r"""Compare quantized module in a model with its floating point counterpart,
feeding both of them the same input. Return a dict with key corresponding to
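(The rest of the docstring is folded.) compare_model_stub packages the pieces above: it prepares the model with stubs, runs the data through q_model, and harvests every ShadowLogger's stats. A hedged end-to-end sketch on the same toy model as before:

import torch
import torch.nn as nn
from torch.ao.ns._numeric_suite import compare_model_stub

float_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
q_model = torch.ao.quantization.quantize_dynamic(
    float_model, {nn.Linear}, dtype=torch.qint8
)
data = torch.randn(4, 8)
ob_dict = compare_model_stub(float_model, q_model, {nn.Linear}, data)
for name, stats in ob_dict.items():
    # Each entry holds the paired outputs one ShadowLogger recorded.
    print(name, len(stats["float"]), len(stats["quantized"]))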
@@ -419,7 +447,8 @@ def compare_model_stub(


def get_matching_activations(
float_module: nn.Module, q_module: nn.Module,
float_module: nn.Module,
q_module: nn.Module,
) -> Dict[str, Dict[str, torch.Tensor]]:
r"""Find the matching activation between float and quantized modules.
@@ -432,7 +461,9 @@
entry being a dictionary with two keys 'float' and 'quantized', containing
the matching float and quantized activations
"""
torch._C._log_api_usage_once("quantization_api._numeric_suite.get_matching_activations")
torch._C._log_api_usage_once(
"quantization_api._numeric_suite.get_matching_activations"
)
float_dict = get_logger_dict(float_module)
quantized_dict = get_logger_dict(q_module)
act_dict: Dict[str, Dict] = {}
@@ -451,7 +482,7 @@ def prepare_model_outputs(
float_module: nn.Module,
q_module: nn.Module,
logger_cls=OutputLogger,
allow_list=None
allow_list=None,
) -> None:
r"""Prepare the model by attaching the logger to both float module
and quantized module if they are in the allow_list.
@@ -462,20 +493,24 @@ def prepare_model_outputs(
logger_cls: type of logger to be attached to float_module and q_module
allow_list: list of module types to attach logger
"""
torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs")
torch._C._log_api_usage_once(
"quantization_api._numeric_suite.prepare_model_outputs"
)
if allow_list is None:
allow_list = get_default_compare_output_module_list()

qconfig_debug = torch.ao.quantization.QConfig(activation=logger_cls, weight=None)
float_module.qconfig = qconfig_debug # type: ignore[assignment]
prepare(float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={})
prepare(
float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={}
)
q_module.qconfig = qconfig_debug # type: ignore[assignment]
prepare(
q_module,
inplace=True,
allow_list=allow_list,
observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
prepare_custom_config_dict={}
prepare_custom_config_dict={},
)


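prepare_model_outputs and get_matching_activations (above) are the two halves of compare_model_outputs below: the first attaches logger_cls as the activation observer through eager-mode prepare, the second pairs up what the loggers on both models recorded. A hedged sketch of driving the two halves by hand, on the same toy model:

import torch
import torch.nn as nn
from torch.ao.ns._numeric_suite import (
    get_matching_activations,
    prepare_model_outputs,
)

float_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
q_model = torch.ao.quantization.quantize_dynamic(
    float_model, {nn.Linear}, dtype=torch.qint8
)
data = torch.randn(4, 8)
prepare_model_outputs(float_model, q_model)  # attach OutputLoggers
_ = float_model(data)  # populate the float-side loggers
_ = q_model(data)  # populate the quantized-side loggers
act_dict = get_matching_activations(float_model, q_model)
print(sorted(act_dict.keys()))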
@@ -484,7 +519,7 @@ def compare_model_outputs(
q_model: nn.Module,
*data,
logger_cls=OutputLogger,
allow_list=None
allow_list=None,
) -> Dict[str, Dict[str, torch.Tensor]]:
r"""Compare output activations between float and quantized models at
corresponding locations for the same input. Return a dict with key corresponding
@@ -517,7 +552,9 @@
and each entry being a dictionary with two keys 'float' and 'quantized',
containing the matching float and quantized activations
"""
torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_outputs")
torch._C._log_api_usage_once(
"quantization_api._numeric_suite.compare_model_outputs"
)
if allow_list is None:
allow_list = get_default_compare_output_module_list()
prepare_model_outputs(float_model, q_model, logger_cls, allow_list)
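Finally, a matching sketch for compare_model_outputs itself, which performs the prepare / run / match sequence in a single call and returns the paired activations (toy model and SQNR arithmetic illustrative, as above):

import torch
import torch.nn as nn
from torch.ao.ns._numeric_suite import compare_model_outputs

float_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
q_model = torch.ao.quantization.quantize_dynamic(
    float_model, {nn.Linear}, dtype=torch.qint8
)
data = torch.randn(4, 8)
act_dict = compare_model_outputs(float_model, q_model, data)
for name, acts in act_dict.items():
    f, q = acts["float"][0], acts["quantized"][0]
    q = q.dequantize() if q.is_quantized else q
    # SQNR in dB between matching float and quantized activations.
    print(name, 20 * torch.log10(f.norm() / (f - q).norm()).item())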
(Diffs for the remaining 10 changed files are not rendered here.)