Skip to content

Commit

Permalink
Force use_cache to be False in PyTorch (huggingface#15385)
Browse files Browse the repository at this point in the history
* use_cache = False for PT models if labels is passed

* Fix for BigBirdPegasusForConditionalGeneration

* add warning if users specify use_cache=True

* Use logger.warning instead of warnings.warn

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
  • Loading branch information
ydshieh and ydshieh authored Feb 8, 2022
1 parent 0acd84f commit 6a5472a
Show file tree
Hide file tree
Showing 9 changed files with 27 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/transformers/models/bart/modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1318,6 +1318,9 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2513,6 +2513,9 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,6 +1287,9 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1258,6 +1258,9 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/led/modeling_led.py
Original file line number Diff line number Diff line change
Expand Up @@ -2366,6 +2366,9 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/marian/modeling_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,6 +1291,9 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/mbart/modeling_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,9 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/pegasus/modeling_pegasus.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,6 +1381,9 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2832,6 +2832,9 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

Expand Down

0 comments on commit 6a5472a

Please sign in to comment.