From 6a5472a8e1d75e95e8fb4d0bdf8ddf9b237bac03 Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Tue, 8 Feb 2022 16:20:53 +0100
Subject: [PATCH] Force use_cache to be False in PyTorch (#15385)

* use_cache = False for PT models if labels is passed

* Fix for BigBirdPegasusForConditionalGeneration

* add warning if users specify use_cache=True

* Use logger.warning instead of warnings.warn

Co-authored-by: ydshieh
---
 src/transformers/models/bart/modeling_bart.py              | 3 +++
 .../models/bigbird_pegasus/modeling_bigbird_pegasus.py     | 3 +++
 src/transformers/models/blenderbot/modeling_blenderbot.py  | 3 +++
 .../models/blenderbot_small/modeling_blenderbot_small.py   | 3 +++
 src/transformers/models/led/modeling_led.py                | 3 +++
 src/transformers/models/marian/modeling_marian.py          | 3 +++
 src/transformers/models/mbart/modeling_mbart.py            | 3 +++
 src/transformers/models/pegasus/modeling_pegasus.py        | 3 +++
 .../modeling_{{cookiecutter.lowercase_modelname}}.py       | 3 +++
 9 files changed, 27 insertions(+)

diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index 381204cf2d30..70edf96cd2e3 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -1318,6 +1318,9 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
             if decoder_input_ids is None and decoder_inputs_embeds is None:
                 decoder_input_ids = shift_tokens_right(
                     labels, self.config.pad_token_id, self.config.decoder_start_token_id
diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index c4db1b8bec5a..aa961e0f599b 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -2513,6 +2513,9 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
             if decoder_input_ids is None and decoder_inputs_embeds is None:
                 decoder_input_ids = shift_tokens_right(
                     labels, self.config.pad_token_id, self.config.decoder_start_token_id
diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py
index df098dd6e195..7751a74f9658 100755
--- a/src/transformers/models/blenderbot/modeling_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -1287,6 +1287,9 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
             if decoder_input_ids is None:
                 decoder_input_ids = shift_tokens_right(
                     labels, self.config.pad_token_id, self.config.decoder_start_token_id
diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
index 5875a827fab9..a22c4d0ce6cc 100755
--- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -1258,6 +1258,9 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
             if decoder_input_ids is None:
                 decoder_input_ids = shift_tokens_right(
                     labels, self.config.pad_token_id, self.config.decoder_start_token_id
diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py
index 8054b9ee6d33..e775fd35c933 100755
--- a/src/transformers/models/led/modeling_led.py
+++ b/src/transformers/models/led/modeling_led.py
@@ -2366,6 +2366,9 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
             if decoder_input_ids is None:
                 decoder_input_ids = shift_tokens_right(
                     labels, self.config.pad_token_id, self.config.decoder_start_token_id
diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py
index f3bd96eeb94f..20cbd21f76e8 100755
--- a/src/transformers/models/marian/modeling_marian.py
+++ b/src/transformers/models/marian/modeling_marian.py
@@ -1291,6 +1291,9 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
             if decoder_input_ids is None:
                 decoder_input_ids = shift_tokens_right(
                     labels, self.config.pad_token_id, self.config.decoder_start_token_id
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index fc09f0a7e62f..3e747b4b1ebc 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -1314,6 +1314,9 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
             if decoder_input_ids is None:
                 decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
 
diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py
index 32923ce44d38..5eed41254e80 100755
--- a/src/transformers/models/pegasus/modeling_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_pegasus.py
@@ -1381,6 +1381,9 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
             if decoder_input_ids is None:
                 decoder_input_ids = shift_tokens_right(
                     labels, self.config.pad_token_id, self.config.decoder_start_token_id
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
index 63233c4bf9ff..69a0d8176683 100755
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
@@ -2832,6 +2832,9 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
             if decoder_input_ids is None:
                 decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
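
Illustration (not part of the patch): a minimal sketch of the behavior the change introduces, assuming a transformers build that includes this patch. With `labels` present, an explicit `use_cache=True` now only triggers the new warning and the forward pass runs with `use_cache=False`; the BART checkpoint and the input strings below are placeholders chosen for the example.

    # Minimal sketch, assuming the patched transformers version is installed.
    from transformers import BartForConditionalGeneration, BartTokenizer

    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

    inputs = tokenizer("UN Chief says there is no military solution in Syria", return_tensors="pt")
    labels = tokenizer("Le chef de l'ONU affirme qu'il n'y a pas de solution militaire", return_tensors="pt").input_ids

    # Because labels are provided, use_cache=True is overridden and the patched forward() logs:
    #   "The `use_cache` argument is changed to `False` since `labels` is provided."
    outputs = model(**inputs, labels=labels, use_cache=True)

    print(outputs.loss)             # the LM loss is still computed as before
    print(outputs.past_key_values)  # None, since use_cache was forced to False

The design choice here is that cached decoder key/value states are only useful for incremental generation; during teacher-forced training (labels passed) they are never reused, so disabling the cache avoids returning unused tensors.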