
Commit 8447d70

yiliu30 and chensuyue authored
Enhance 3.x API (#1397)
Signed-off-by: yiliu30 <yi4.liu@intel.com> Co-authored-by: chensuyue <suyue.chen@intel.com>
1 parent 54e4d43 commit 8447d70

File tree

19 files changed: +1438 −63 lines changed

.azure-pipelines/scripts/ut/run_3x_pt.sh

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ echo "${test_case}"
 # install requirements
 echo "set up UT env..."
 pip install -r /neural-compressor/requirements_pt.txt
+pip install transformers
 pip install coverage
 pip install pytest
 pip list

.azure-pipelines/ut-3x-pt.yml

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@ pr:
       - setup.py
       - requirements.txt
       - requirements_pt.txt
-      - .azure-pipelines/scripts/ut

 pool: ICX-16C

.azure-pipelines/ut-basic-no-cover.yml

Lines changed: 0 additions & 1 deletion
@@ -12,7 +12,6 @@ pr:
       - test
      - setup.py
       - requirements.txt
-      - .azure-pipelines/scripts/ut
     exclude:
       - test/neural_coder
       - test/3x

.azure-pipelines/ut-basic.yml

Lines changed: 0 additions & 1 deletion
@@ -12,7 +12,6 @@ pr:
       - test
       - setup.py
       - requirements.txt
-      - .azure-pipelines/scripts/ut
     exclude:
       - test/neural_coder
       - test/3x

neural_compressor/common/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from neural_compressor.common.logger import level, log, info, debug, warn, warning, error, fatal

neural_compressor/common/base_config.py

Lines changed: 117 additions & 25 deletions
@@ -19,10 +19,14 @@
 
 import json
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Optional, Union
+from collections import OrderedDict
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+from neural_compressor.common.logger import Logger
+from neural_compressor.common.utility import BASE_CONFIG, COMPOSABLE_CONFIG, GLOBAL, LOCAL
+
+logger = Logger().get_logger()
 
-from neural_compressor.common.utility import BASE_CONFIG, GLOBAL, OPERATOR_NAME
-from neural_compressor.utils import logger
 
 # Dictionary to store registered configurations
 registered_configs = {}
@@ -57,31 +61,47 @@ class BaseConfig(ABC):
     name = BASE_CONFIG
 
     def __init__(self) -> None:
-        self.global_config: Optional[BaseConfig] = None
+        self._global_config: Optional[BaseConfig] = None
         # For PyTorch, operator_type is the collective name for module type and functional operation type,
         # for example, `torch.nn.Linear`, and `torch.nn.functional.linear`.
-        self.operator_type_config: Dict[Union[str, Callable], Optional[BaseConfig]] = {}
-        self.operator_name_config: Dict[str, Optional[BaseConfig]] = {}
-
-    def set_operator_name(self, operator_name: str, config: BaseConfig) -> BaseConfig:
-        self.operator_name_config[operator_name] = config
-        return self
-
-    def _set_operator_type(self, operator_type: Union[str, Callable], config: BaseConfig) -> BaseConfig:
-        # TODO (Yi), clean the usage
-        # hide it from user, as we can use set_operator_name with regular expression to convert its functionality
-        self.operator_type_config[operator_type] = config
+        # local config is the collections of operator_type configs and operator configs
+        self._local_config: Dict[str, Optional[BaseConfig]] = {}
+
+    @property
+    def global_config(self):
+        if self._global_config is None:
+            self._global_config = self.__class__(**self.to_dict())
+        return self._global_config
+
+    @global_config.setter
+    def global_config(self, config):
+        self._global_config = config
+
+    @property
+    def local_config(self):
+        return self._local_config
+
+    @local_config.setter
+    def local_config(self, config):
+        self._local_config = config
+
+    def set_local(self, operator_name: str, config: BaseConfig) -> BaseConfig:
+        if operator_name in self.local_config:
+            logger.warning("The configuration for %s has already been set, update it.", operator_name)
+        if self.global_config is None:
+            self.global_config = self.__class__(**self.to_dict())
+        self.local_config[operator_name] = config
         return self
 
     def to_dict(self, params_list=[], operator2str=None):
         result = {}
         global_config = {}
         for param in params_list:
             global_config[param] = getattr(self, param)
-        if bool(self.operator_name_config):
-            result[OPERATOR_NAME] = {}
-            for op_name, config in self.operator_name_config.items():
-                result[OPERATOR_NAME][op_name] = config.to_dict()
+        if bool(self.local_config):
+            result[LOCAL] = {}
+            for op_name, config in self.local_config.items():
+                result[LOCAL][op_name] = config.to_dict()
             result[GLOBAL] = global_config
         else:
             result = global_config
@@ -99,10 +119,10 @@ def from_dict(cls, config_dict, str2operator=None):
             The constructed config.
         """
         config = cls(**config_dict.get(GLOBAL, {}))
-        operator_config = config_dict.get(OPERATOR_NAME, {})
+        operator_config = config_dict.get(LOCAL, {})
         if operator_config:
             for op_name, op_config in operator_config.items():
-                config.set_operator_name(op_name, cls(**op_config))
+                config.set_local(op_name, cls(**op_config))
         return config
 
     @classmethod
@@ -120,7 +140,7 @@ def to_json_file(self, filename):
         config_dict = self.to_dict()
         with open(filename, "w", encoding="utf-8") as file:
             json.dump(config_dict, file, indent=4)
-        logger.info(f"Dump the config into {filename}")
+        logger.info("Dump the config into %s.", filename)
 
     def to_json_string(self, use_diff: bool = False) -> str:
         """Serializes this instance to a JSON string.
@@ -137,7 +157,7 @@ def to_json_string(self, use_diff: bool = False) -> str:
             config_dict = self.to_diff_dict(self)
         else:
             config_dict = self.to_dict()
-        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+        return json.dumps(config_dict, indent=2) + "\n"
 
     def __repr__(self) -> str:
         return f"{self.__class__.__name__} {self.to_json_string()}"
@@ -154,10 +174,82 @@ def validate(self, user_config: BaseConfig):
         pass
 
     def __add__(self, other: BaseConfig) -> BaseConfig:
-        # TODO(Yi) implement config add, like RTNWeightOnlyQuantConfig() + GPTQWeightOnlyQuantConfig()
-        pass
+        if isinstance(other, type(self)):
+            for op_name, config in other.local_config.items():
+                self.set_local(op_name, config)
+            return self
+        else:
+            return ComposableConfig(configs=[self, other])
+
+    def _get_op_name_op_type_config(self):
+        op_type_config_dict = dict()
+        op_name_config_dict = dict()
+        for name, config in self.local_config.items():
+            if self._is_op_type(name):
+                op_type_config_dict[name] = config
+            else:
+                op_name_config_dict[name] = config
+        return op_type_config_dict, op_name_config_dict
+
+    def to_config_mapping(
+        self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
+    ) -> OrderedDict[Union[str, Callable], OrderedDict[str, BaseConfig]]:
+        config_mapping = OrderedDict()
+        if config_list is None:
+            config_list = [self]
+        for config in config_list:
+            global_config = config.global_config
+            op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config()
+            for op_name, op_type in model_info:
+                config_mapping.setdefault(op_type, OrderedDict())[op_name] = global_config
+                if op_type in op_type_config_dict:
+                    config_mapping[op_type][op_name] = op_name_config_dict[op_type]
+                if op_name in op_name_config_dict:
+                    config_mapping[op_type][op_name] = op_name_config_dict[op_name]
+        return config_mapping
 
     @staticmethod
     def _is_op_type(name: str) -> bool:
         # TODO (Yi), ort and tf need override it
         return not isinstance(name, str)
+
+
+class ComposableConfig(BaseConfig):
+    name = COMPOSABLE_CONFIG
+
+    def __init__(self, configs: List[BaseConfig]) -> None:
+        self.config_list = configs
+
+    def __add__(self, other: BaseConfig) -> BaseConfig:
+        if isinstance(other, type(self)):
+            self.config_list.extend(other.config_list)
+        else:
+            self.config_list.append(other)
+        return self
+
+    def to_dict(self, params_list=[], operator2str=None):
+        result = {}
+        for config in self.config_list:
+            result[config.name] = config.to_dict()
+        return result
+
+    @classmethod
+    def from_dict(cls, config_dict, str2operator=None):
+        # TODO(Yi)
+        pass
+
+    def to_json_string(self, use_diff: bool = False) -> str:
+        return json.dumps(self.to_dict(), indent=2) + "\n"
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__} {self.to_json_string()}"
+
+    def to_config_mapping(
+        self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
+    ) -> OrderedDict[str, BaseConfig]:
+        return super().to_config_mapping(self.config_list, model_info)
+
+    @classmethod
+    def register_supported_configs(cls):
+        """Add all supported configs."""
+        raise NotImplementedError
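
The reworked BaseConfig above keeps one global config plus per-operator "local" overrides set via set_local(), and adding two configs of different types returns a ComposableConfig whose to_config_mapping() resolves the config for each (op_name, op_type) pair. A minimal usage sketch, not taken from this commit: it uses the default-config helpers exported by neural_compressor.torch in the same PR, and the operator name and model_info tuples below are purely illustrative.

# Illustrative sketch only -- the op name and model_info are made up for the example.
from neural_compressor.torch import get_default_rtn_config, get_default_dummy_config

rtn_config = get_default_rtn_config()
# Per-operator (local) override for one Linear layer; the current values become the global config.
rtn_config.set_local("model.decoder.layers.0.self_attn.q_proj", get_default_rtn_config())

# Adding a config of a different type returns a ComposableConfig holding both.
combined = rtn_config + get_default_dummy_config()

# model_info is a list of (op_name, op_type) pairs gathered from the model;
# to_config_mapping picks the global or local config for each operator.
model_info = [("model.decoder.layers.0.self_attn.q_proj", "Linear")]
config_mapping = combined.to_config_mapping(model_info=model_info)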

neural_compressor/common/logger.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Logger: handles logging functionalities."""
+
+import logging
+import os
+
+
+class Logger(object):
+    """Logger class."""
+
+    __instance = None
+
+    def __new__(cls):
+        """Create a singleton Logger instance."""
+        if Logger.__instance is None:
+            Logger.__instance = object.__new__(cls)
+            Logger.__instance._log()
+        return Logger.__instance
+
+    def _log(self):
+        """Setup the logger format and handler."""
+        LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper()
+        self._logger = logging.getLogger("neural_compressor")
+        self._logger.handlers.clear()
+        self._logger.setLevel(LOGLEVEL)
+        formatter = logging.Formatter(
+            "%(asctime)s [%(levelname)s][%(filename)s:%(lineno)d] %(message)s", "%Y-%m-%d %H:%M:%S"
+        )
+        streamHandler = logging.StreamHandler()
+        streamHandler.setFormatter(formatter)
+        self._logger.addHandler(streamHandler)
+        self._logger.propagate = False
+
+    def get_logger(self):
+        """Get the logger."""
+        return self._logger
+
+
+def _pretty_dict(value, indent=0):
+    """Make the logger dict pretty."""
+    prefix = "\n" + " " * (indent + 4)
+    if isinstance(value, dict):
+        items = [prefix + repr(key) + ": " + _pretty_dict(value[key], indent + 4) for key in value]
+        return "{%s}" % (",".join(items) + "\n" + " " * indent)
+    elif isinstance(value, list):
+        items = [prefix + _pretty_dict(item, indent + 4) for item in value]
+        return "[%s]" % (",".join(items) + "\n" + " " * indent)
+    elif isinstance(value, tuple):
+        items = [prefix + _pretty_dict(item, indent + 4) for item in value]
+        return "(%s)" % (",".join(items) + "\n" + " " * indent)
+    else:
+        return repr(value)
+
+
+level = Logger().get_logger().level
+DEBUG = logging.DEBUG
+
+
+def log(level, msg, *args, **kwargs):
+    """Output log with the level as a parameter."""
+    if isinstance(msg, dict):
+        for _, line in enumerate(_pretty_dict(msg).split("\n")):
+            Logger().get_logger().log(level, line, *args, **kwargs)
+    else:
+        Logger().get_logger().log(level, msg, *args, **kwargs)
+
+
+def debug(msg, *args, **kwargs):
+    """Output log with the debug level."""
+    if isinstance(msg, dict):
+        for _, line in enumerate(_pretty_dict(msg).split("\n")):
+            Logger().get_logger().debug(line, *args, **kwargs)
+    else:
+        Logger().get_logger().debug(msg, *args, **kwargs)
+
+
+def error(msg, *args, **kwargs):
+    """Output log with the error level."""
+    if isinstance(msg, dict):
+        for _, line in enumerate(_pretty_dict(msg).split("\n")):
+            Logger().get_logger().error(line, *args, **kwargs)
+    else:
+        Logger().get_logger().error(msg, *args, **kwargs)
+
+
+def fatal(msg, *args, **kwargs):
+    """Output log with the fatal level."""
+    if isinstance(msg, dict):
+        for _, line in enumerate(_pretty_dict(msg).split("\n")):
+            Logger().get_logger().fatal(line, *args, **kwargs)
+    else:
+        Logger().get_logger().fatal(msg, *args, **kwargs)
+
+
+def info(msg, *args, **kwargs):
+    """Output log with the info level."""
+    if isinstance(msg, dict):
+        for _, line in enumerate(_pretty_dict(msg).split("\n")):
+            Logger().get_logger().info(line, *args, **kwargs)
+    else:
+        Logger().get_logger().info(msg, *args, **kwargs)
+
+
+def warn(msg, *args, **kwargs):
+    """Output log with the warning level."""
+    if isinstance(msg, dict):
+        for _, line in enumerate(_pretty_dict(msg).split("\n")):
+            Logger().get_logger().warning(line, *args, **kwargs)
+    else:
+        Logger().get_logger().warning(msg, *args, **kwargs)
+
+
+def warning(msg, *args, **kwargs):
+    """Output log with the warning level (Alias of the method warn)."""
+    if isinstance(msg, dict):
+        for _, line in enumerate(_pretty_dict(msg).split("\n")):
+            Logger().get_logger().warning(line, *args, **kwargs)
+    else:
+        Logger().get_logger().warning(msg, *args, **kwargs)
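
A short usage sketch of the new logging helpers, which the common/__init__.py change above re-exports from neural_compressor.common; the messages are illustrative.

from neural_compressor.common import info, debug

info("Quantization started.")            # plain messages go straight to the logger
debug({"algorithm": "rtn", "bits": 4})   # dict messages are pretty-printed line by line
# The level is read from the LOGLEVEL environment variable (default INFO),
# so the debug output only appears when LOGLEVEL=DEBUG is set.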

neural_compressor/common/utility.py

Lines changed: 3 additions & 1 deletion
@@ -20,8 +20,10 @@
 
 # constants for configs
 GLOBAL = "global"
-OPERATOR_NAME = "operator_name"
+LOCAL = "local"
 
 # config name
 BASE_CONFIG = "base_config"
+COMPOSABLE_CONFIG = "composable_config"
 RTN_WEIGHT_ONLY_QUANT = "rtn_weight_only_quant"
+DUMMY_CONFIG = "dummy_config"

neural_compressor/torch/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -15,4 +15,10 @@
 from neural_compressor.torch.utils import register_algo
 from neural_compressor.torch.algorithms import rtn_quantize_entry
 
-from neural_compressor.torch.quantization import quantize, RTNWeightQuantConfig, get_default_rtn_config
+from neural_compressor.torch.quantization import (
+    quantize,
+    RTNWeightQuantConfig,
+    get_default_rtn_config,
+    DummyConfig,
+    get_default_dummy_config,
+)
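
A hedged end-to-end sketch of the expanded public API: the quantize() call signature is an assumption, since that function is defined in the quantization module outside this diff, and the toy model is illustrative.

import torch
from neural_compressor.torch import get_default_dummy_config, get_default_rtn_config, quantize

model = torch.nn.Sequential(torch.nn.Linear(8, 8))                     # toy model for illustration
quant_config = get_default_rtn_config() + get_default_dummy_config()   # compose two algorithm configs
q_model = quantize(model, quant_config)  # assumed (model, quant_config) signature, see quantization module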
