Skip to content

Commit 83805b2

Browse files
yiliu30 and pre-commit-ci[bot]
authored and committed
Remove 1.x API (1/N) (#1323)
* Update all experimental APIs caller --------- Signed-off-by: yiliu30 <yi4.liu@intel.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: bmyrcha <bartosz.myrcha@intel.com>
1 parent 7b20588 commit 83805b2

File tree

13 files changed

+429
-563
lines changed

13 files changed

+429
-563
lines changed

neural_compressor/adaptor/tensorflow.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -190,7 +190,7 @@ def train(self, model, dataloader, optimizer_tuple, criterion_tuple, hooks, post
190190
callbacks = kwargs["kwargs"].get("callbacks", None)
191191
execution_mode = kwargs["kwargs"].get("execution_mode", None)
192192
distributed = getattr(dataloader, "distributed", False)
193-
from neural_compressor.experimental.common.criterion import TensorflowKnowledgeDistillationLoss
193+
from neural_compressor.compression.distillation.criterions import TensorflowKnowledgeDistillationLoss
194194

195195
if isinstance(criterion, TensorflowKnowledgeDistillationLoss):
196196
input_model = model._model
@@ -1757,8 +1757,8 @@ def _get_mse_order(
17571757

17581758
def _partial_dataset_of(self, dataloader, confidence_batches):
17591759
"""Partial dataset."""
1760+
from neural_compressor.data.datasets.dummy_dataset import DummyDataset
17601761
from neural_compressor.data.datasets.dummy_dataset import DummyDataset as DummyDataset_v2_x
1761-
from neural_compressor.experimental.data.datasets.dummy_dataset import DummyDataset
17621762

17631763
if isinstance(dataloader.dataset, DummyDataset) or isinstance(dataloader.dataset, DummyDataset_v2_x):
17641764
assert isinstance(confidence_batches, int)

neural_compressor/compression/distillation/criterions.py

Lines changed: 125 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
"""Initialize critetion classes.
1818
1919
Classes includes:
20+
TensorFlowCrossEntropyLoss, PyTorchCrossEntropyLoss,
21+
TensorFlowSparseCategoricalCrossentropy,
2022
TensorflowKnowledgeDistillationLoss, PyTorchKnowledgeDistillationLoss,
2123
PyTorchIntermediateLayersKnowledgeDistillationLoss.
2224
"""
@@ -91,7 +93,12 @@ def __getitem__(self, criterion_type):
9193
Returns:
9294
cls: criterion class.
9395
"""
94-
assert criterion_type in self.criterions.keys(), "only support criterions in {}".format(self.criterions.keys())
96+
assert (
97+
criterion_type in self.criterions.keys()
98+
), "only support criterions in {} \
99+
, but got criterion type {}".format(
100+
self.criterions.keys(), criterion_type
101+
)
95102

96103
return self.criterions[criterion_type]
97104

@@ -130,6 +137,119 @@ def decorator_criterion(cls):
130137
return decorator_criterion
131138

132139

140+
@criterion_registry("CrossEntropyLoss", "tensorflow")
class TensorFlowCrossEntropyLoss(object):
    """TensorFlow CrossEntropyLoss criterion."""

    def __init__(self, param_dict):
        """Initialize the TensorFlowCrossEntropyLoss class.

        Args:
            param_dict (dict): The dict of parameters setting by user for CrossEntropyLoss criterion.

        Raises:
            AssertionError: If `param_dict` is not a dict.
        """
        assert isinstance(param_dict, dict), "This criterion constructor parameter must be a dict"
        self._param_dict = param_dict

    def _mapping(self):
        """Translate user parameter names into tf.keras.losses keyword arguments.

        Only keys present in the internal parameter map are forwarded; all
        other user-supplied keys are silently ignored.

        Returns:
            dict: Keyword arguments for tf.keras.losses.CategoricalCrossentropy.

        Raises:
            AssertionError: If an unsupported `reduction` value is given.
        """
        _param_map = {"reduction": "reduction", "from_logits": "from_logits"}
        _dict = {}
        for key in self._param_dict:
            if key in _param_map:
                if key == "reduction":
                    # Restrict reduction to the modes tf.keras losses accept.
                    assert self._param_dict[key] in [
                        "auto",
                        "none",
                        "sum",
                        "sum_over_batch_size",
                    ], "Supported reduction value for tensorflow is auto, none, sum, sum_over_batch_size"
                _dict.update({_param_map[key]: self._param_dict[key]})
        return _dict

    def __call__(self):
        """Call the TensorFlowCrossEntropyLoss.

        Returns:
            cls: tf.keras.losses.CategoricalCrossentropy class (not an instance).
            param_dict (dict): Mapped keyword arguments for constructing the loss.
        """
        return tf.keras.losses.CategoricalCrossentropy, self._mapping()
176+
177+
178+
@criterion_registry("SparseCategoricalCrossentropy", "tensorflow")
class TensorFlowSparseCategoricalCrossentropy(object):
    """TensorFlow SparseCategoricalCrossentropyLoss criterion."""

    def __init__(self, param_dict):
        """Initialize the TensorFlowSparseCategoricalCrossentropy class.

        Args:
            param_dict (dict): The dict of parameters setting by user for
                SparseCategoricalCrossentropy criterion.

        Raises:
            AssertionError: If `param_dict` is not a dict.
        """
        assert isinstance(param_dict, dict), "This criterion constructor parameter must be a dict"
        self._param_dict = param_dict

    def _mapping(self):
        """Translate user parameter names into tf.keras.losses keyword arguments.

        Only keys present in the internal parameter map are forwarded; all
        other user-supplied keys are silently ignored.

        Returns:
            dict: Keyword arguments for tf.keras.losses.SparseCategoricalCrossentropy.

        Raises:
            AssertionError: If an unsupported `reduction` value is given.
        """
        _param_map = {"reduction": "reduction", "from_logits": "from_logits"}
        _dict = {}
        for key in self._param_dict:
            if key in _param_map:
                if key == "reduction":
                    # Restrict reduction to the modes tf.keras losses accept.
                    assert self._param_dict[key] in [
                        "auto",
                        "none",
                        "sum",
                        "sum_over_batch_size",
                    ], "Supported reduction value for tensorflow is auto, none, sum, sum_over_batch_size"
                _dict.update({_param_map[key]: self._param_dict[key]})
        return _dict

    def __call__(self):
        """Call the TensorFlowSparseCategoricalCrossentropy.

        Returns:
            cls: tf.keras.losses.SparseCategoricalCrossentropy class (not an instance).
            param_dict (dict): Mapped keyword arguments for constructing the loss.
        """
        return tf.keras.losses.SparseCategoricalCrossentropy, self._mapping()
214+
215+
216+
@criterion_registry("CrossEntropyLoss", "pytorch")
class PyTorchCrossEntropyLoss(object):
    """PyTorch CrossEntropyLoss criterion."""

    def __init__(self, param_dict):
        """Initialize the PyTorchCrossEntropyLoss class.

        Args:
            param_dict (dict): The dict of parameters setting by user for
                CrossEntropyLoss criterion.

        Raises:
            AssertionError: If `param_dict` is not a dict.
        """
        assert isinstance(param_dict, dict), "This criterion constructor parameter must be a dict"
        self._param_dict = param_dict

    def _mapping(self):
        """Translate user parameter names into torch.nn.CrossEntropyLoss keyword arguments.

        Only keys present in the internal parameter map are forwarded; all
        other user-supplied keys are silently ignored.

        Returns:
            dict: Keyword arguments for torch.nn.CrossEntropyLoss.

        Raises:
            AssertionError: If an unsupported `reduction` value is given.
        """
        _param_map = {"reduction": "reduction"}
        _dict = {}
        for key in self._param_dict:
            if key in _param_map:
                if key == "reduction":
                    # Restrict reduction to the modes torch.nn.CrossEntropyLoss accepts.
                    assert self._param_dict[key] in [
                        "none",
                        "mean",
                        "sum",
                    ], "Supported reduction value is none, mean, sum"
                _dict.update({_param_map[key]: self._param_dict[key]})
        return _dict

    def __call__(self):
        """Call the PyTorchCrossEntropyLoss.

        Returns:
            cls: torch.nn.CrossEntropyLoss class (not an instance).
            param_dict (dict): Mapped keyword arguments for constructing the loss.
        """
        return torch.nn.CrossEntropyLoss, self._mapping()
251+
252+
133253
class KnowledgeDistillationFramework(object):
134254
"""Knowledge Distillation Framework."""
135255

@@ -916,7 +1036,7 @@ def register_hooks_for_models(self):
9161036
Raises:
9171037
AttributeError: AttributeError
9181038
"""
919-
from neural_compressor.experimental.common import torch_utils
1039+
from neural_compressor.compression.distillation import utility
9201040

9211041
def register_model_forward_hook(model, path, output_process="", student=False):
9221042
module = model
@@ -927,7 +1047,7 @@ def register_model_forward_hook(model, path, output_process="", student=False):
9271047
module = module.__getattr__(node)
9281048
except:
9291049
raise AttributeError("There is no path {} in the model.".format(path))
930-
return module.register_forward_hook(torch_utils.get_activation(path, output_process, student))
1050+
return module.register_forward_hook(utility.get_activation(path, output_process, student))
9311051

9321052
assert isinstance(self.student_model, torch.nn.Module) and isinstance(self.teacher_model, torch.nn.Module), (
9331053
"Expect student_model and teacher_model to be an torch.nn.Module object, "
@@ -939,8 +1059,8 @@ def register_model_forward_hook(model, path, output_process="", student=False):
9391059
student_output_process, teacher_output_process = self.layer_output_process[idx]
9401060
st_handle = register_model_forward_hook(self.student_model, student_layer, student_output_process, True)
9411061
te_handle = register_model_forward_hook(self.teacher_model, teacher_layer, teacher_output_process)
942-
torch_utils.STUDENT_FEATURES = self.student_features
943-
torch_utils.TEACHER_FEATURES = self.teacher_features
1062+
utility.STUDENT_FEATURES = self.student_features
1063+
utility.TEACHER_FEATURES = self.teacher_features
9441064
self.hook_handles.extend([st_handle, te_handle])
9451065

9461066
def remove_all_hooks(self):

0 commit comments

Comments (0)