Updated patch for NNCF integration to the latest transformers release #1328
Conversation
- default_label_names = find_labels(self.model.__class__)
+ model_class = self.model.__class__
+ if isinstance(self.model, NNCFNetwork):
+     model_class = self.model.get_nncf_wrapped_model().__class__
+ default_label_names = find_labels(model_class)
In the new version, labels are determined from the model signature, which is why the NNCFNetwork needs to be unwrapped.
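The unwrapping above can be sketched with stand-ins (`BertLikeModel`, `WrappedModel`, and the simplified `find_labels` below are hypothetical, not the actual transformers/NNCF classes): signature-based label discovery fails on a wrapper whose `forward` takes `*args/**kwargs`, but works once the wrapped model class is recovered.

```python
import inspect

class BertLikeModel:
    """Stand-in for a transformers model whose forward() takes `labels`."""
    def forward(self, input_ids=None, attention_mask=None, labels=None):
        pass

class WrappedModel:
    """Mimics NNCFNetwork: proxies calls, but obscures the original signature."""
    def __init__(self, model):
        self._model = model

    def get_nncf_wrapped_model(self):
        return self._model

    def forward(self, *args, **kwargs):
        return self._model.forward(*args, **kwargs)

def find_labels(model_class):
    # Simplified version of signature-based label discovery:
    # pick forward() parameters whose name contains "label".
    sig = inspect.signature(model_class.forward)
    return [p for p in sig.parameters if "label" in p]

model = WrappedModel(BertLikeModel())

# Without unwrapping, the wrapper's *args/**kwargs signature yields nothing:
print(find_labels(model.__class__))                           # []
# Unwrapping recovers the real signature:
print(find_labels(model.get_nncf_wrapped_model().__class__))  # ['labels']
```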
model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / steps_in_epoch
- self.state.curr_loss = curr_loss.cpu().detach().item()
+ self.state.curr_loss = tr_loss_step.cpu().detach().item()
tr_loss and tr_loss_step are explicitly separated.
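A minimal sketch of that separation (plain floats and a stand-in `TrainerState`, not the actual Trainer internals): the running total `tr_loss` and the per-step value `tr_loss_step` are distinct, so the state can record the latest step's loss without touching the accumulated sum.

```python
class TrainerState:
    """Stand-in for the Trainer state object from the patch."""
    def __init__(self):
        self.global_step = 0
        self.curr_loss = 0.0

state = TrainerState()
tr_loss = 0.0                          # running total, used for logging averages
for tr_loss_step in [0.9, 0.7, 0.4]:   # stand-in per-step losses
    tr_loss += tr_loss_step            # accumulate the total
    state.global_step += 1
    state.curr_loss = tr_loss_step     # record only the latest step's loss

print(state.curr_loss)  # 0.4
```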
+++ b/src/transformers/utils/__init__.py
@@ -154,6 +154,7 @@ from .import_utils import (

WEIGHTS_NAME = "pytorch_model.bin"
+NNCF_PT_STATE_NAME = "nncf_state.bin"
file_utils.py now just does backward-compatibility imports.
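That backward-compatibility pattern can be sketched as follows (illustrated with throwaway modules created on the fly; the module names and constants mirror the hunks above, but this is not the actual transformers code): the old module star-imports everything that moved, so old import paths keep resolving.

```python
import sys
import types

# Stands in for the new location, transformers/utils/__init__.py:
utils = types.ModuleType("utils")
utils.WEIGHTS_NAME = "pytorch_model.bin"
utils.NNCF_PT_STATE_NAME = "nncf_state.bin"
sys.modules["utils"] = utils

# Stands in for the compat shim, file_utils.py, whose body is just a re-export:
file_utils = types.ModuleType("file_utils")
exec("from utils import *", file_utils.__dict__)
sys.modules["file_utils"] = file_utils

# The old import path still works after the move:
from file_utils import WEIGHTS_NAME
print(WEIGHTS_NAME)  # pytorch_model.bin
```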
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES

WEIGHTS_NAME = "pytorch_model.bin"
+NNCF_PT_STATE_NAME = "nncf_state.bin"
That was moved to utils/__init__.py
@@ -1205,6 +1146,9 @@ index 000000000..edbe0f84d
+    "initializer": {
+        "range": {
+            "num_init_samples": 24
+        },
+        "batchnorm_adaptation": {
+            "num_bn_adaptation_samples": 0
BatchNorm adaptation is useless for transformers because there are no BN layers. It can only be turned off by specifying a zero number of samples. This saves some validation time, at least.
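The config fragment above can be written as a plain Python dict to show the disabling mechanism (the `bn_adaptation_enabled` helper and its default of 2000 samples are assumptions for illustration, not NNCF's actual API):

```python
# Mirror of the JSON fragment from the patch:
nncf_config = {
    "initializer": {
        "range": {"num_init_samples": 24},
        # Zero samples is the only way to switch BN adaptation off:
        "batchnorm_adaptation": {"num_bn_adaptation_samples": 0},
    }
}

def bn_adaptation_enabled(config):
    # Hypothetical helper sketching how such a flag could be read;
    # the non-zero default when the key is absent is an assumption.
    bn = config.get("initializer", {}).get("batchnorm_adaptation", {})
    return bn.get("num_bn_adaptation_samples", 2000) > 0

print(bn_adaptation_enabled(nncf_config))  # False
print(bn_adaptation_enabled({}))           # True
```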
# Model has labels -> use them.
- if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
-     if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)):
+ if config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
The config is passed to the model later, so the check has to be done on the config directly.
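The check in the hunk can be sketched in pure Python (`FakePretrainedConfig` is a stand-in for `transformers.PretrainedConfig`; its default `LABEL_0`, `LABEL_1`, ... mapping mirrors the library's behavior): a config whose `label2id` differs from the default one for the same `num_labels` carries custom labels and should be trusted.

```python
class FakePretrainedConfig:
    """Stand-in: default label2id is LABEL_0..LABEL_{n-1}, like transformers."""
    def __init__(self, num_labels=2, label2id=None):
        self.label2id = label2id or {f"LABEL_{i}": i for i in range(num_labels)}

label_list = ["B-ORG", "O"]
config = FakePretrainedConfig(label2id={"B-ORG": 0, "O": 1})

# Done on `config` rather than `model.config`, since the config
# is only handed to the model afterwards:
has_custom_labels = config.label2id != FakePretrainedConfig(num_labels=2).label2id
labels_match = sorted(config.label2id) == sorted(label_list)

print(has_custom_labels, labels_match)  # True True
```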
- AutoConfig,
- AutoModelForTokenClassification,
- AutoTokenizer,
- DataCollatorForTokenClassification,
I just haven't reordered the imports, for consistency with the other samples.
@@ -204,7 +204,8 @@ def test_glue_distilbert_eval(self, temp_folder):

@pytest.mark.dependency(depends=['install_trans'], name='lm_train')
def test_lm_train(self, temp_folder):
-    com_line = "examples/pytorch/language-modeling/run_clm.py --model_name_or_path gpt2" \
+    # GPT2 is loaded via torch.frombuffer which is not available in torch==1.9.1 yet
+    com_line = "examples/pytorch/language-modeling/run_clm.py --model_name_or_path distilgpt2" \
GPT2 is loaded via safetensors functionality that uses torch.frombuffer. That function only appeared in torch 1.10, but we support torch==1.9.1 only.
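A version-compatibility guard for that situation can be sketched with a feature check (shown against stand-in objects rather than real torch, so it runs anywhere; `supports_frombuffer` is a hypothetical helper, not part of torch or transformers):

```python
def supports_frombuffer(torch_module):
    # Feature detection rather than version parsing:
    # torch.frombuffer exists only from torch 1.10 onward.
    return hasattr(torch_module, "frombuffer")

class FakeTorch19:
    """Mimics torch==1.9.1: no frombuffer attribute."""

class FakeTorch110:
    """Mimics torch>=1.10: provides frombuffer."""
    @staticmethod
    def frombuffer(buffer, dtype=None):
        return "tensor"

print(supports_frombuffer(FakeTorch19()))   # False
print(supports_frombuffer(FakeTorch110()))  # True
```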
@@ -21,7 +21,7 @@

from tests.common.helpers import PROJECT_ROOT
from tests.torch.helpers import Command

-TRANSFORMERS_COMMIT = "bff1c71e84e392af9625c345f9ea71f7b6d75fb3"
+TRANSFORMERS_COMMIT = "bd469c40659ce76c81f69c7726759d249b4aef49"
v4.23.1
@vshampor 3rd party sanity tests are green (Build 23)
Changes
Updated patch for NNCF integration to the latest transformers release (v4.23.1).
Reason for changes
Needed for smoother integration of JPQD #1319.
Related tickets
94449
Tests
3rd party sanity