1919
2020import json
2121from abc import ABC , abstractmethod
22- from typing import Any , Callable , Dict , Optional , Union
22+ from collections import OrderedDict
23+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
24+
25+ from neural_compressor .common .logger import Logger
26+ from neural_compressor .common .utility import BASE_CONFIG , COMPOSABLE_CONFIG , GLOBAL , LOCAL
27+
28+ logger = Logger ().get_logger ()
2329
24- from neural_compressor .common .utility import BASE_CONFIG , GLOBAL , OPERATOR_NAME
25- from neural_compressor .utils import logger
2630
2731# Dictionary to store registered configurations
2832registered_configs = {}
@@ -57,31 +61,47 @@ class BaseConfig(ABC):
5761 name = BASE_CONFIG
5862
5963 def __init__ (self ) -> None :
60- self .global_config : Optional [BaseConfig ] = None
64+ self ._global_config : Optional [BaseConfig ] = None
6165 # For PyTorch, operator_type is the collective name for module type and functional operation type,
6266 # for example, `torch.nn.Linear`, and `torch.nn.functional.linear`.
63- self .operator_type_config : Dict [Union [str , Callable ], Optional [BaseConfig ]] = {}
64- self .operator_name_config : Dict [str , Optional [BaseConfig ]] = {}
65-
66- def set_operator_name (self , operator_name : str , config : BaseConfig ) -> BaseConfig :
67- self .operator_name_config [operator_name ] = config
68- return self
69-
70- def _set_operator_type (self , operator_type : Union [str , Callable ], config : BaseConfig ) -> BaseConfig :
71- # TODO (Yi), clean the usage
72- # hide it from user, as we can use set_operator_name with regular expression to convert its functionality
73- self .operator_type_config [operator_type ] = config
67+ # local config is the collections of operator_type configs and operator configs
68+ self ._local_config : Dict [str , Optional [BaseConfig ]] = {}
69+
70+ @property
71+ def global_config (self ):
72+ if self ._global_config is None :
73+ self ._global_config = self .__class__ (** self .to_dict ())
74+ return self ._global_config
75+
76+ @global_config .setter
77+ def global_config (self , config ):
78+ self ._global_config = config
79+
80+ @property
81+ def local_config (self ):
82+ return self ._local_config
83+
84+ @local_config .setter
85+ def local_config (self , config ):
86+ self ._local_config = config
87+
88+ def set_local (self , operator_name : str , config : BaseConfig ) -> BaseConfig :
89+ if operator_name in self .local_config :
90+ logger .warning ("The configuration for %s has already been set, update it." , operator_name )
91+ if self .global_config is None :
92+ self .global_config = self .__class__ (** self .to_dict ())
93+ self .local_config [operator_name ] = config
7494 return self
7595
7696 def to_dict (self , params_list = [], operator2str = None ):
7797 result = {}
7898 global_config = {}
7999 for param in params_list :
80100 global_config [param ] = getattr (self , param )
81- if bool (self .operator_name_config ):
82- result [OPERATOR_NAME ] = {}
83- for op_name , config in self .operator_name_config .items ():
84- result [OPERATOR_NAME ][op_name ] = config .to_dict ()
101+ if bool (self .local_config ):
102+ result [LOCAL ] = {}
103+ for op_name , config in self .local_config .items ():
104+ result [LOCAL ][op_name ] = config .to_dict ()
85105 result [GLOBAL ] = global_config
86106 else :
87107 result = global_config
@@ -99,10 +119,10 @@ def from_dict(cls, config_dict, str2operator=None):
99119 The constructed config.
100120 """
101121 config = cls (** config_dict .get (GLOBAL , {}))
102- operator_config = config_dict .get (OPERATOR_NAME , {})
122+ operator_config = config_dict .get (LOCAL , {})
103123 if operator_config :
104124 for op_name , op_config in operator_config .items ():
105- config .set_operator_name (op_name , cls (** op_config ))
125+ config .set_local (op_name , cls (** op_config ))
106126 return config
107127
108128 @classmethod
@@ -120,7 +140,7 @@ def to_json_file(self, filename):
120140 config_dict = self .to_dict ()
121141 with open (filename , "w" , encoding = "utf-8" ) as file :
122142 json .dump (config_dict , file , indent = 4 )
123- logger .info (f "Dump the config into { filename } " )
143+ logger .info ("Dump the config into %s." , filename )
124144
125145 def to_json_string (self , use_diff : bool = False ) -> str :
126146 """Serializes this instance to a JSON string.
@@ -137,7 +157,7 @@ def to_json_string(self, use_diff: bool = False) -> str:
137157 config_dict = self .to_diff_dict (self )
138158 else :
139159 config_dict = self .to_dict ()
140- return json .dumps (config_dict , indent = 2 , sort_keys = True ) + "\n "
160+ return json .dumps (config_dict , indent = 2 ) + "\n "
141161
142162 def __repr__ (self ) -> str :
143163 return f"{ self .__class__ .__name__ } { self .to_json_string ()} "
@@ -154,10 +174,82 @@ def validate(self, user_config: BaseConfig):
154174 pass
155175
156176 def __add__ (self , other : BaseConfig ) -> BaseConfig :
157- # TODO(Yi) implement config add, like RTNWeightOnlyQuantConfig() + GPTQWeightOnlyQuantConfig()
158- pass
177+ if isinstance (other , type (self )):
178+ for op_name , config in other .local_config .items ():
179+ self .set_local (op_name , config )
180+ return self
181+ else :
182+ return ComposableConfig (configs = [self , other ])
183+
184+ def _get_op_name_op_type_config (self ):
185+ op_type_config_dict = dict ()
186+ op_name_config_dict = dict ()
187+ for name , config in self .local_config .items ():
188+ if self ._is_op_type (name ):
189+ op_type_config_dict [name ] = config
190+ else :
191+ op_name_config_dict [name ] = config
192+ return op_type_config_dict , op_name_config_dict
193+
194+ def to_config_mapping (
195+ self , config_list : List [BaseConfig ] = None , model_info : List [Tuple [str , str ]] = None
196+ ) -> OrderedDict [Union [str , Callable ], OrderedDict [str , BaseConfig ]]:
197+ config_mapping = OrderedDict ()
198+ if config_list is None :
199+ config_list = [self ]
200+ for config in config_list :
201+ global_config = config .global_config
202+ op_type_config_dict , op_name_config_dict = config ._get_op_name_op_type_config ()
203+ for op_name , op_type in model_info :
204+ config_mapping .setdefault (op_type , OrderedDict ())[op_name ] = global_config
205+ if op_type in op_type_config_dict :
206+ config_mapping [op_type ][op_name ] = op_name_config_dict [op_type ]
207+ if op_name in op_name_config_dict :
208+ config_mapping [op_type ][op_name ] = op_name_config_dict [op_name ]
209+ return config_mapping
159210
160211 @staticmethod
161212 def _is_op_type (name : str ) -> bool :
162213 # TODO (Yi), ort and tf need override it
163214 return not isinstance (name , str )
215+
216+
217+ class ComposableConfig (BaseConfig ):
218+ name = COMPOSABLE_CONFIG
219+
220+ def __init__ (self , configs : List [BaseConfig ]) -> None :
221+ self .config_list = configs
222+
223+ def __add__ (self , other : BaseConfig ) -> BaseConfig :
224+ if isinstance (other , type (self )):
225+ self .config_list .extend (other .config_list )
226+ else :
227+ self .config_list .append (other )
228+ return self
229+
230+ def to_dict (self , params_list = [], operator2str = None ):
231+ result = {}
232+ for config in self .config_list :
233+ result [config .name ] = config .to_dict ()
234+ return result
235+
236+ @classmethod
237+ def from_dict (cls , config_dict , str2operator = None ):
238+ # TODO(Yi)
239+ pass
240+
241+ def to_json_string (self , use_diff : bool = False ) -> str :
242+ return json .dumps (self .to_dict (), indent = 2 ) + "\n "
243+
244+ def __repr__ (self ) -> str :
245+ return f"{ self .__class__ .__name__ } { self .to_json_string ()} "
246+
247+ def to_config_mapping (
248+ self , config_list : List [BaseConfig ] = None , model_info : List [Tuple [str , str ]] = None
249+ ) -> OrderedDict [str , BaseConfig ]:
250+ return super ().to_config_mapping (self .config_list , model_info )
251+
252+ @classmethod
253+ def register_supported_configs (cls ):
254+ """Add all supported configs."""
255+ raise NotImplementedError
0 commit comments