Skip to content

Commit

Permalink
[mypy] hwconfig (#3189)
Browse files Browse the repository at this point in the history
### Changes

Enable mypy check for `nncf/common/hardware/config.py`
  • Loading branch information
AlexanderDokuchaev authored Jan 14, 2025
1 parent 2a5ee2a commit f355847
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 22 deletions.
32 changes: 17 additions & 15 deletions nncf/common/hardware/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Type

import jstyleson as json
import jstyleson as json # type: ignore[import-untyped]

import nncf
from nncf.common.graph.operator_metatypes import OperatorMetatype
Expand Down Expand Up @@ -60,7 +60,7 @@ def get_hw_config_type(target_device: str) -> Optional[HWConfigType]:
return HWConfigType(HW_CONFIG_TYPE_TARGET_DEVICE_MAP[target_device])


class HWConfig(list, ABC):
class HWConfig(list[Dict[str, Any]], ABC):
QUANTIZATION_ALGORITHM_NAME = "quantization"
ATTRIBUTES_NAME = "attributes"
SCALE_ATTRIBUTE_NAME = "scales"
Expand All @@ -69,23 +69,23 @@ class HWConfig(list, ABC):

TYPE_TO_CONF_NAME_DICT = {HWConfigType.CPU: "cpu.json", HWConfigType.NPU: "npu.json", HWConfigType.GPU: "gpu.json"}

def __init__(self):
def __init__(self) -> None:
super().__init__()
self.registered_algorithm_configs = {}
self.registered_algorithm_configs: Dict[str, Any] = {}
self.target_device = None

@abstractmethod
def _get_available_operator_metatypes_for_matching(self) -> List[Type[OperatorMetatype]]:
pass

@staticmethod
def get_path_to_hw_config(hw_config_type: HWConfigType):
def get_path_to_hw_config(hw_config_type: HWConfigType) -> str:
return "/".join(
[NNCF_PACKAGE_ROOT_DIR, HW_CONFIG_RELATIVE_DIR, HWConfig.TYPE_TO_CONF_NAME_DICT[hw_config_type]]
)

@classmethod
def from_dict(cls, dct: dict):
def from_dict(cls, dct: Dict[str, Any]) -> "HWConfig":
hw_config = cls()
hw_config.target_device = dct["target_device"]

Expand All @@ -104,7 +104,7 @@ def from_dict(cls, dct: dict):
for algorithm_name in op_dict:
if algorithm_name not in hw_config.registered_algorithm_configs:
continue
tmp_config = {}
tmp_config: Dict[str, List[Dict[str, Any]]] = {}
for algo_and_op_specific_field_name, algorithm_configs in op_dict[algorithm_name].items():
if not isinstance(algorithm_configs, list):
algorithm_configs = [algorithm_configs]
Expand All @@ -129,30 +129,30 @@ def from_dict(cls, dct: dict):
return hw_config

@classmethod
def from_json(cls, path):
def from_json(cls: type["HWConfig"], path: str) -> List[Dict[str, Any]]:
file_path = Path(path).resolve()
with safe_open(file_path) as f:
json_config = json.load(f, object_pairs_hook=OrderedDict)
return cls.from_dict(json_config)

@staticmethod
def get_quantization_mode_from_config_value(str_val: str):
def get_quantization_mode_from_config_value(str_val: str) -> str:
if str_val == "symmetric":
return QuantizationMode.SYMMETRIC
if str_val == "asymmetric":
return QuantizationMode.ASYMMETRIC
raise nncf.ValidationError("Invalid quantization type specified in HW config")

@staticmethod
def get_is_per_channel_from_config_value(str_val: str):
def get_is_per_channel_from_config_value(str_val: str) -> bool:
if str_val == "perchannel":
return True
if str_val == "pertensor":
return False
raise nncf.ValidationError("Invalid quantization granularity specified in HW config")

@staticmethod
def get_qconf_from_hw_config_subdict(quantization_subdict: Dict):
def get_qconf_from_hw_config_subdict(quantization_subdict: Dict[str, Any]) -> QuantizerConfig:
bits = quantization_subdict["bits"]
mode = HWConfig.get_quantization_mode_from_config_value(quantization_subdict["mode"])
is_per_channel = HWConfig.get_is_per_channel_from_config_value(quantization_subdict["granularity"])
Expand Down Expand Up @@ -181,20 +181,22 @@ def get_qconf_from_hw_config_subdict(quantization_subdict: Dict):
)

@staticmethod
def is_qconf_list_corresponding_to_unspecified_op(qconf_list: Optional[List[QuantizerConfig]]):
def is_qconf_list_corresponding_to_unspecified_op(qconf_list: Optional[List[QuantizerConfig]]) -> bool:
return qconf_list is None

@staticmethod
def is_wildcard_quantization(qconf_list: Optional[List[QuantizerConfig]]):
def is_wildcard_quantization(qconf_list: Optional[List[QuantizerConfig]]) -> bool:
# Corresponds to an op itself being specified in the HW config, but having no associated quantization
# configs specified
return qconf_list is not None and len(qconf_list) == 0

def get_metatype_vs_quantizer_configs_map(
self, for_weights=False
self, for_weights: bool = False
) -> Dict[Type[OperatorMetatype], Optional[List[QuantizerConfig]]]:
# 'None' for ops unspecified in HW config, empty list for wildcard quantization ops
retval = {k: None for k in self._get_available_operator_metatypes_for_matching()}
retval: Dict[Type[OperatorMetatype], Optional[List[QuantizerConfig]]] = {
k: None for k in self._get_available_operator_metatypes_for_matching()
}
config_key = "weights" if for_weights else "activations"
for op_dict in self:
hw_config_op_name = op_dict["type"]
Expand Down
4 changes: 2 additions & 2 deletions nncf/common/quantization/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import nncf
from nncf.common.graph import NNCFNode
Expand Down Expand Up @@ -45,7 +45,7 @@ class QuantizerConfig:
def __init__(
self,
num_bits: int = QUANTIZATION_BITS,
mode: QuantizationScheme = QuantizationScheme.SYMMETRIC,
mode: Union[QuantizationScheme, str] = QuantizationScheme.SYMMETRIC, # TODO(AlexanderDokuchaev): use enum
signedness_to_force: Optional[bool] = None,
per_channel: bool = QUANTIZATION_PER_CHANNEL,
):
Expand Down
10 changes: 6 additions & 4 deletions nncf/common/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
import itertools
import os
import os.path as osp
import pathlib
from typing import Any, Dict, Hashable, Iterable, List, Optional, Union
from pathlib import Path
from typing import Any, Dict, Hashable, Iterable, List, Optional, TypeVar, Union

from tabulate import tabulate

from nncf.common.utils.os import is_windows

TKey = TypeVar("TKey", bound=Hashable)


def create_table(
header: List[str],
Expand All @@ -44,7 +46,7 @@ def create_table(
return tabulate(tabular_data=rows, headers=header, tablefmt=table_fmt, maxcolwidths=max_col_widths, floatfmt=".3f")


def configure_accuracy_aware_paths(log_dir: Union[str, pathlib.Path]) -> Union[str, pathlib.Path]:
def configure_accuracy_aware_paths(log_dir: Union[str, Path]) -> Union[str, Path]:
"""
Create a subdirectory inside of the passed log directory
to save checkpoints from the accuracy-aware training loop to.
Expand All @@ -59,7 +61,7 @@ def configure_accuracy_aware_paths(log_dir: Union[str, pathlib.Path]) -> Union[s
return acc_aware_log_dir


def product_dict(d: Dict[Hashable, List[str]]) -> Iterable[Dict[Hashable, str]]:
def product_dict(d: Dict[TKey, List[Any]]) -> Iterable[Dict[TKey, Any]]:
"""
Generates dicts which enumerate the options for keys given in the input dict;
options are represented by list values in the input dict.
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ exclude = [
"nncf/common/composite_compression.py",
"nncf/common/compression.py",
"nncf/common/deprecation.py",
"nncf/common/hardware/config.py",
"nncf/common/logging/progress_bar.py",
"nncf/common/logging/track_progress.py",
"nncf/common/pruning/clusterization.py",
Expand Down

0 comments on commit f355847

Please sign in to comment.