1717"""Initialize critetion classes.
1818
1919Classes includes:
20+ TensorFlowCrossEntropyLoss, PyTorchCrossEntropyLoss,
21+ TensorFlowSparseCategoricalCrossentropy,
2022 TensorflowKnowledgeDistillationLoss, PyTorchKnowledgeDistillationLoss,
2123 PyTorchIntermediateLayersKnowledgeDistillationLoss.
2224"""
@@ -91,7 +93,12 @@ def __getitem__(self, criterion_type):
9193 Returns:
9294 cls: criterion class.
9395 """
94- assert criterion_type in self .criterions .keys (), "only support criterions in {}" .format (self .criterions .keys ())
96+ assert (
97+ criterion_type in self .criterions .keys ()
98+ ), "only support criterions in {} \
99+ , but got criterion type {}" .format (
100+ self .criterions .keys (), criterion_type
101+ )
95102
96103 return self .criterions [criterion_type ]
97104
@@ -130,6 +137,119 @@ def decorator_criterion(cls):
130137 return decorator_criterion
131138
132139
140+ @criterion_registry ("CrossEntropyLoss" , "tensorflow" )
141+ class TensorFlowCrossEntropyLoss (object ):
142+ """TensorFlow CrossEntropyLoss criterion."""
143+
144+ def __init__ (self , param_dict ):
145+ """Initialize the Datasets class.
146+
147+ Args:
148+ param_dict (dict): The dict of parameters setting by user for CrossEntropyLoss criterion.
149+ """
150+ assert isinstance (param_dict , dict ), "This criterion constructor parameter must be a dict"
151+ self ._param_dict = param_dict
152+
153+ def _mapping (self ):
154+ _param_map = {"reduction" : "reduction" , "from_logits" : "from_logits" }
155+ _dict = {}
156+ for key in self ._param_dict :
157+ if key in _param_map :
158+ if key == "reduction" :
159+ assert self ._param_dict [key ] in [
160+ "auto" ,
161+ "none" ,
162+ "sum" ,
163+ "sum_over_batch_size" ,
164+ ], "Supported reduction value for tensorflow is auto, none, sum, sum_over_batch_size"
165+ _dict .update ({_param_map [key ]: self ._param_dict [key ]})
166+ return _dict
167+
168+ def __call__ (self ):
169+ """Call the TensorFlowCrossEntropyLoss.
170+
171+ Returns:
172+ cls: criterion class.
173+ param_dict(dict): param_dict
174+ """
175+ return tf .keras .losses .CategoricalCrossentropy , self ._mapping ()
176+
177+
178+ @criterion_registry ("SparseCategoricalCrossentropy" , "tensorflow" )
179+ class TensorFlowSparseCategoricalCrossentropy (object ):
180+ """TensorFlow SparseCategoricalCrossentropyLoss criterion."""
181+
182+ def __init__ (self , param_dict ):
183+ """Initialize the Datasets class.
184+
185+ Args:
186+ param_dict (string): param_dict.
187+ """
188+ assert isinstance (param_dict , dict ), "This criterion constructor parameter must be a dict"
189+ self ._param_dict = param_dict
190+
191+ def _mapping (self ):
192+ _param_map = {"reduction" : "reduction" , "from_logits" : "from_logits" }
193+ _dict = {}
194+ for key in self ._param_dict :
195+ if key in _param_map :
196+ if key == "reduction" :
197+ assert self ._param_dict [key ] in [
198+ "auto" ,
199+ "none" ,
200+ "sum" ,
201+ "sum_over_batch_size" ,
202+ ], "Supported reduction value for tensorflow is auto, none, sum, sum_over_batch_size"
203+ _dict .update ({_param_map [key ]: self ._param_dict [key ]})
204+ return _dict
205+
206+ def __call__ (self ):
207+ """Call the TensorFlowSparseCategoricalCrossentropy.
208+
209+ Returns:
210+ cls: criterion class.
211+ param_dict(dict): param_dict
212+ """
213+ return tf .keras .losses .SparseCategoricalCrossentropy , self ._mapping ()
214+
215+
216+ @criterion_registry ("CrossEntropyLoss" , "pytorch" )
217+ class PyTorchCrossEntropyLoss (object ):
218+ """PyTorch CrossEntropyLoss criterion."""
219+
220+ def __init__ (self , param_dict ):
221+ """Initialize the PyTorchCrossEntropyLoss class.
222+
223+ Args:
224+ param_dict (string): param_dict.
225+ """
226+ assert isinstance (param_dict , dict ), "This criterion constructor parameter must be a dict"
227+ self ._param_dict = param_dict
228+
229+ def _mapping (self ):
230+ _param_map = {"reduction" : "reduction" }
231+ _dict = {}
232+ for key in self ._param_dict :
233+ if key in _param_map :
234+ if key == "reduction" :
235+ assert self ._param_dict [key ] in [
236+ "none" ,
237+ "mean" ,
238+ "sum" ,
239+ ], "Supported reduction value is none, mean, sum"
240+ _dict .update ({_param_map [key ]: self ._param_dict [key ]})
241+ return _dict
242+
243+ def __call__ (self ):
244+ """Call the PyTorchCrossEntropyLoss.
245+
246+ Returns:
247+ cls: criterion class.
248+ param_dict(dict): param_dict
249+ """
250+ return torch .nn .CrossEntropyLoss , self ._mapping ()
251+
252+
133253class KnowledgeDistillationFramework (object ):
134254 """Knowledge Distillation Framework."""
135255
@@ -916,7 +1036,7 @@ def register_hooks_for_models(self):
9161036 Raises:
9171037 AttributeError: AttributeError
9181038 """
919- from neural_compressor .experimental . common import torch_utils
1039+ from neural_compressor .compression . distillation import utility
9201040
9211041 def register_model_forward_hook (model , path , output_process = "" , student = False ):
9221042 module = model
@@ -927,7 +1047,7 @@ def register_model_forward_hook(model, path, output_process="", student=False):
9271047 module = module .__getattr__ (node )
9281048 except :
9291049 raise AttributeError ("There is no path {} in the model." .format (path ))
930- return module .register_forward_hook (torch_utils .get_activation (path , output_process , student ))
1050+ return module .register_forward_hook (utility .get_activation (path , output_process , student ))
9311051
9321052 assert isinstance (self .student_model , torch .nn .Module ) and isinstance (self .teacher_model , torch .nn .Module ), (
9331053 "Expect student_model and teacher_model to be an torch.nn.Module object, "
@@ -939,8 +1059,8 @@ def register_model_forward_hook(model, path, output_process="", student=False):
9391059 student_output_process , teacher_output_process = self .layer_output_process [idx ]
9401060 st_handle = register_model_forward_hook (self .student_model , student_layer , student_output_process , True )
9411061 te_handle = register_model_forward_hook (self .teacher_model , teacher_layer , teacher_output_process )
942- torch_utils .STUDENT_FEATURES = self .student_features
943- torch_utils .TEACHER_FEATURES = self .teacher_features
1062+ utility .STUDENT_FEATURES = self .student_features
1063+ utility .TEACHER_FEATURES = self .teacher_features
9441064 self .hook_handles .extend ([st_handle , te_handle ])
9451065
9461066 def remove_all_hooks (self ):
0 commit comments