Fixed pylint error (intel#18)
* Fixed pylint error
PenghuiCheng authored Apr 7, 2022
1 parent 00b472a commit e2d4176
Showing 7 changed files with 61 additions and 47 deletions.
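The changes fall into two buckets: long lines are re-wrapped to satisfy pylint's line-too-long check (C0301), and imports of optional backends such as nncf, apex, and torch_xla are annotated against import-error (E0401). As a minimal sketch, assuming a stock pylint setup, the inline suppression idiom looks like the block below; note that stock pylint only honors the "# pylint: disable=..." spelling, whereas the bare "# disable=..." comments added in this diff are plain annotations that pylint itself will not act on.

    # Sketch: suppressing import-error (E0401) for an optional dependency.
    try:
        import nncf  # pylint: disable=import-error
    except ImportError:
        nncf = None  # callers must check for None before using nncf

    # Re-wrapping a long line, as most hunks below do, removes a C0301
    # warning outright, with no suppression comment needed.
    long_message = (
        "implicitly concatenated string pieces "
        "stay under the line-length limit"
    )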
@@ -1,4 +1,4 @@
 datasets >= 1.8.0
 torch >= 1.10.0
 transformers>=4.12.0
-wandb
+wandb
6 changes: 3 additions & 3 deletions nlp_toolkit/optimization/auto_distillation.py
@@ -145,10 +145,10 @@ def train_evaluate(self, model):
         Returns:
             Evaluated metrics of the model.
         """
-        assert self.train_func is not None and self.eval_func is not None, \
+        assert self._train_func is not None and self._eval_func is not None, \
             "train_func and eval_func must be set."
-        model = self.train_func(model)
-        return self.eval_func(model)
+        model = self._train_func(model)
+        return self._eval_func(model)

     def metrics_conversion(self, metrics):
         if isinstance(metrics, dict):
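The rename from self.train_func to self._train_func makes internal reads go through the backing fields rather than the public names. A plausible shape for the surrounding class, sketched under the assumption that train_func and eval_func are properties over _train_func and _eval_func (the property definitions themselves are outside this diff):

    # Sketch of the backing-field/property pattern the rename implies.
    class AutoDistillationSketch:
        def __init__(self):
            self._train_func = None
            self._eval_func = None

        @property
        def train_func(self):
            return self._train_func

        @train_func.setter
        def train_func(self, func):
            self._train_func = func

        def train_evaluate(self, model):
            # Internal code reads the backing field directly.
            assert self._train_func is not None, "train_func must be set."
            return self._train_func(model)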
2 changes: 1 addition & 1 deletion nlp_toolkit/optimization/config.py
@@ -491,7 +491,7 @@ def __init__(
         metrics: Union[List, Metric] = None,
     ):
         super().__init__()
-        from nncf import NNCFConfig
+        from nncf import NNCFConfig  # disable=E0401
         assert isinstance(nncf_config, NNCFConfig)
         self.nncf_config = nncf_config
         if metrics is not None:
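nncf is an optional dependency, so pylint reports E0401 (import-error) whenever it is missing from the lint environment; deferring the import into __init__, as this file does, also keeps importing the module itself cheap. A guarded variant of the same idea, as a sketch (the class name and error message are illustrative, not the toolkit's code):

    class NncfConfigHolder:
        """Sketch: lazy, guarded import of an optional dependency."""

        def __init__(self, nncf_config):
            try:
                # Deferred on purpose; the relevant pylint codes are
                # import-error (E0401) and import-outside-toplevel (C0415).
                from nncf import NNCFConfig  # pylint: disable=import-error,import-outside-toplevel
            except ImportError as err:
                raise ImportError("nncf is required for NNCF quantization") from err
            assert isinstance(nncf_config, NNCFConfig)
            self.nncf_config = nncf_config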
44 changes: 31 additions & 13 deletions nlp_toolkit/optimization/model.py
@@ -66,7 +66,10 @@ def from_pretrained(
             ("resume_download", False),
             ("revision", None),
         ]
-        download_kwargs = {name: kwargs.get(name, default_value) for (name, default_value) in download_kwarg_default}
+        download_kwargs = {
+            name: kwargs.get(name, default_value)
+            for (name, default_value) in download_kwarg_default
+        }

         config = AutoConfig.from_pretrained(model_name_or_path)
         model_class = eval(f'transformers.{config.architectures[0]}')
@@ -82,7 +85,8 @@
         keys_to_ignore_on_load_unexpected = copy.deepcopy(
             getattr(model_class, "_keys_to_ignore_on_load_unexpected", None)
         )
-        keys_to_ignore_on_load_missing = copy.deepcopy(getattr(model_class, "_keys_to_ignore_on_load_missing", None))
+        keys_to_ignore_on_load_missing = \
+            copy.deepcopy(getattr(model_class, "_keys_to_ignore_on_load_missing", None))

         # Avoid unnecessary warnings resulting from quantized model initialization
         quantized_keys_to_ignore_on_load = [r"zero_point", r"scale",
@@ -91,7 +95,9 @@
         if keys_to_ignore_on_load_unexpected is None:
             model_class._keys_to_ignore_on_load_unexpected = quantized_keys_to_ignore_on_load
         else:
-            model_class._keys_to_ignore_on_load_unexpected.extend(quantized_keys_to_ignore_on_load)
+            model_class._keys_to_ignore_on_load_unexpected.extend(
+                quantized_keys_to_ignore_on_load
+            )
         missing_keys_to_ignore_on_load = [r"weight", r"bias"]
         if keys_to_ignore_on_load_missing is None:
             model_class._keys_to_ignore_on_load_missing = missing_keys_to_ignore_on_load
@@ -104,7 +110,9 @@
             model_class._keys_to_ignore_on_load_missing = keys_to_ignore_on_load_missing

         if not os.path.isdir(model_name_or_path) and not os.path.isfile(model_name_or_path):
-            config_file = hf_bucket_url(model_name_or_path, filename="best_configure.yaml", revision=download_kwargs["revision"])
+            config_file = hf_bucket_url(model_name_or_path,
+                                        filename="best_configure.yaml",
+                                        revision=download_kwargs["revision"])

             try:
                 resolved_config_file = cached_path(
@@ -117,19 +125,25 @@
                 logger.error(err)
                 msg = (
                     f"Can't load config for '{model_name_or_path}'. Make sure that:\n\n"
-                    f"-'{model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
-                    f"-or '{model_name_or_path}' is a correct path to a directory containing a best_configure.yaml file\n\n"
+                    f"-'{model_name_or_path}' is a correct model identifier listed on "
+                    f"'https://huggingface.co/models'\n\n"
+                    f"-or '{model_name_or_path}' is a correct path to a directory containing "
+                    f"a best_configure.yaml file\n\n"
                 )

                 if download_kwargs["revision"] is not None:
                     msg += (
-                        f"- or {download_kwargs['revision']} is a valid git identifier (branch name, a tag name, or a commit id) that "
-                        f"exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"
+                        f"- or {download_kwargs['revision']} is a valid git "
+                        f"identifier (branch name, a tag name, or a commit id) that "
+                        f"exists for this model name as listed on its model page on "
+                        f"'https://huggingface.co/models'\n\n"
                     )

                 raise EnvironmentError(msg)

-            config_file = hf_bucket_url(model_name_or_path, filename="best_model_weights.pt", revision=download_kwargs["revision"])
+            config_file = hf_bucket_url(model_name_or_path,
+                                        filename="best_model_weights.pt",
+                                        revision=download_kwargs["revision"])
             try:
                 resolved_config_file = cached_path(
                     config_file,
@@ -141,14 +155,18 @@
                 logger.error(err)
                 msg = (
                     f"Can't load config for '{model_name_or_path}'. Make sure that:\n\n"
-                    f"-'{model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
-                    f"-or '{model_name_or_path}' is a correct path to a directory containing a best_model_weights.pt file\n\n"
+                    f"-'{model_name_or_path}' is a correct model identifier listed on "
+                    f"'https://huggingface.co/models'\n\n"
+                    f"-or '{model_name_or_path}' is a correct path to a directory containing "
+                    f"a best_model_weights.pt file\n\n"
                 )

                 if download_kwargs["revision"] is not None:
                     msg += (
-                        f"- or {download_kwargs['revision']} is a valid git identifier (branch name, a tag name, or a commit id) that "
-                        f"exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"
+                        f"- or {download_kwargs['revision']} is a valid git identifier "
+                        f"(branch name, a tag name, or a commit id) that "
+                        f"exists for this model name as listed on its model page on "
+                        f"'https://huggingface.co/models'\n\n"
                     )

                 raise EnvironmentError(msg)
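Every hunk in this file is a line-length fix: the hf_bucket_url calls gain one keyword argument per line, and the long f-strings are split using Python's implicit concatenation of adjacent string literals. A standalone sketch of the string technique (the model name is an illustrative value):

    model_name_or_path = "distilbert-base-uncased"  # illustrative

    # Adjacent literals are joined at compile time, so the message is
    # byte-for-byte identical to its single-line version.
    msg = (
        f"Can't load config for '{model_name_or_path}'. Make sure that:\n\n"
        f"-'{model_name_or_path}' is a correct model identifier listed on "
        f"'https://huggingface.co/models'\n\n"
    )
    assert "listed on 'https://huggingface.co/models'" in msg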
30 changes: 11 additions & 19 deletions nlp_toolkit/optimization/optimizer.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import logging
+import os
 import torch

 from neural_compressor.experimental import common, Component, Distillation
@@ -100,6 +101,10 @@ def __init__(
         self._train_func = None
         self._calib_dataloader = None
         self.output_dir = output_dir
+        self.quant_config = None
+        self.pruning_config = None
+        self.distillation_config = None
+        self._provider = Provider.INC.value

     @property
     def eval_func(self):
@@ -125,7 +130,7 @@ def train_func(self, func: Callable):
     def calib_dataloader(self, dataloader):
         self._calib_dataloader = dataloader

-    def _init_quantize(self):
+    def _init_quantizer(self):
         from .quantization import QuantizationMode
         from neural_compressor.experimental import Quantization, common
         assert isinstance(self.quant_config, QuantizationConfig), \
@@ -157,7 +162,7 @@ def _init_quantize(self):
         self.quantizer = quantizer
         return quantizer

-    def _nncf_quantize(self):
+    def _nncf_quantize(self):  # disable=E0401
         from nlp_toolkit import NncfConfig
         from nncf import create_compressed_model
         assert isinstance(self.quant_config, NncfConfig), \
@@ -175,7 +180,8 @@ def _nncf_quantize(self):
         )

         self.compression_ctrl = \
-            compression_algo_controller.distributed() if self.quant_config.distributed else compression_algo_controller
+            compression_algo_controller.distributed() if self.quant_config.distributed \
+            else compression_algo_controller

     def quantize(
         self,
@@ -197,9 +203,9 @@ def quantize(
             self._calib_dataloader = calib_dataloader
         quantizer = self._init_quantizer()
         opt_model = quantizer.fit()
-        opt_model.save(self.args.output_dir)
+        opt_model.save(self.output_dir)
         logger.info(
-            "quantized model and configure file have saved to {}".format(self.args.output_dir)
+            "quantized model and configure file have saved to {}".format(self.output_dir)
         )
         return opt_model.model

@@ -209,20 +215,6 @@ def _init_pruner(self):
         assert isinstance(self.pruning_config, PruningConfig), \
             "please pass a instance of PruningConfig to NoTrainerOptimizer.prune!"

-        pruning_start_epoch, pruning_end_epoch = self.pruning_config.epoch_range
-
-        if pruning_start_epoch > self.args.num_train_epochs - 1:
-            logger.warning(
-                f"Pruning end epoch {pruning_start_epoch} is higher than the total number of training epoch "
-                f"{self.args.num_train_epochs}. No pruning will be applied."
-            )
-
-        if pruning_end_epoch > self.args.num_train_epochs - 1:
-            logger.warning(
-                f"Pruning end epoch {pruning_end_epoch} is higher than the total number of training epoch "
-                f"{self.args.num_train_epochs}. The target sparsity will not be reached."
-            )
-
         pruner = Pruning(self.pruning_config.inc_config)
         pruner.model = common.Model(self.model)
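Beyond the _init_quantize to _init_quantizer rename (quantize() already calls self._init_quantizer()) and the self.args.output_dir to self.output_dir fix, the new __init__ assignments predeclare attributes that were previously first assigned inside other methods; the deleted epoch-range block likewise referenced self.args, which this class does not define. Predeclaring in __init__ is the usual remedy for pylint's attribute-defined-outside-init warning (W0201); a minimal sketch of the pattern, with illustrative names:

    class OptimizerSketch:
        def __init__(self):
            # Declared up front so the attribute exists from construction
            # onward; a first assignment inside quantize() would raise W0201.
            self.quant_config = None

        def quantize(self, quant_config):
            self.quant_config = quant_config  # re-assignment, not a first definition
            return self.quant_config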
20 changes: 12 additions & 8 deletions nlp_toolkit/optimization/trainer.py
@@ -61,13 +61,13 @@
 import datasets

 if is_torch_tpu_available():
-    import torch_xla.core.xla_model as xm
+    import torch_xla.core.xla_model as xm  # disable=E0401

 if is_apex_available():
-    from apex import amp
+    from apex import amp  # disable=E0401

 if is_sagemaker_mp_enabled():
-    from .trainer_pt_utils import smp_forward_backward
+    from .trainer_pt_utils import smp_forward_backward  # disable=E0401

 if TYPE_CHECKING:
     import optuna
@@ -189,7 +189,7 @@ def _init_quantizer(self):
         self.quantizer = quantizer
         return quantizer

-    def _nncf_quantize(self):
+    def _nncf_quantize(self):  # disable=E0401
         from nlp_toolkit import NncfConfig
         from nncf import create_compressed_model
         compression_state = None
@@ -272,13 +272,15 @@ def _init_pruner(self):

         if pruning_start_epoch > self.args.num_train_epochs - 1:
             logger.warning(
-                f"Pruning end epoch {pruning_start_epoch} is higher than the total number of training epoch "
+                f"Pruning end epoch {pruning_start_epoch} is higher than "
+                f"the total number of training epoch "
                 f"{self.args.num_train_epochs}. No pruning will be applied."
             )

         if pruning_end_epoch > self.args.num_train_epochs - 1:
             logger.warning(
-                f"Pruning end epoch {pruning_end_epoch} is higher than the total number of training epoch "
+                f"Pruning end epoch {pruning_end_epoch} is higher than "
+                f"the total number of training epoch "
                 f"{self.args.num_train_epochs}. The target sparsity will not be reached."
             )
@@ -331,7 +333,8 @@ def _init_distiller(self):
         if self._eval_func is not None:
             distiller.eval_func = self._eval_func
         else:
-            assert self.metrics is not None, "Please pass metrics to trainer.distillation.metrics!"
+            assert self.metrics is not None, \
+                "Please pass metrics to trainer.distillation.metrics!"
             distiller.eval_func = self.builtin_eval_func

         distiller.train_func = \
@@ -861,7 +864,8 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
             return loss_mb.reduce_mean().detach().to(self.args.device)

-        if self.use_amp:
+        if self.use_amp:  # disable=E0401
             from torch.cuda.amp import autocast
             with autocast():
                 loss = self.compute_loss(model, inputs)
         else:
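The imports flagged in this file sit behind runtime availability guards (is_torch_tpu_available, is_apex_available, is_sagemaker_mp_enabled), so static analysis cannot resolve them whenever the backend is absent, hence the import-error annotations. The use_amp branch uses the same deferred-import shape; a runnable sketch assuming only torch is installed (the function name is illustrative):

    import torch
    from torch import nn

    def training_step_sketch(model: nn.Module, inputs: torch.Tensor,
                             use_amp: bool) -> torch.Tensor:
        # Sketch: forward pass with optional mixed-precision autocast.
        if use_amp:
            # Deferred import keeps AMP out of the common path; pylint
            # flags this pattern as import-outside-toplevel (C0415).
            from torch.cuda.amp import autocast
            with autocast():
                return model(inputs).mean()
        return model(inputs).mean()

    loss = training_step_sketch(nn.Linear(4, 1), torch.randn(2, 4), use_amp=False)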
4 changes: 2 additions & 2 deletions nlp_toolkit/preprocessing/data_augmentation.py
@@ -232,8 +232,8 @@ def text_generation_augmentation(self, extension, raw_datasets):
             min_length = m - std

         total_count = sum(label2count.values())
-        factor = total_count \
-            if num_return_sentences <= 0 else int(math.ceil(num_return_sentences / num_samples))
+        factor = total_count if num_return_sentences <= 0 \
+            else int(math.ceil(num_return_sentences / self._num_samples))
         p0 = label2count[0] / total_count
         p1 = 1 - p0
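The conditional expression is re-broken so the continuation falls before "else", and the divisor becomes self._num_samples. Wrapping the whole expression in parentheses is an equivalent alternative that drops the backslash entirely (PEP 8 prefers implied continuation over backslashes); a sketch with illustrative values:

    import math

    num_return_sentences = 12  # illustrative values
    num_samples = 5
    total_count = 100

    # Parentheses give implicit line continuation, so no backslash is needed.
    factor = (
        total_count
        if num_return_sentences <= 0
        else int(math.ceil(num_return_sentences / num_samples))
    )
    assert factor == 3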
