[LED] fix global_attention_mask not being passed for generation and docs clarification about grad checkpointing (#17112)

* [LED] fixed global_attention_mask not passed for generation + docs clarification for gradient checkpointing

* LED docs clarification

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [LED] gradient_checkpointing=True should be passed to TrainingArguments

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [LED] docs: remove wrong word

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [LED] docs fix typo

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
caesar-one and patrickvonplaten authored May 17, 2022
1 parent bad3583 commit d9050dc
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 4 additions & 2 deletions docs/source/en/model_doc/led.mdx
@@ -44,8 +44,10 @@ Tips:
 - LED makes use of *global attention* by means of the `global_attention_mask` (see
   [`LongformerModel`]). For summarization, it is advised to put *global attention* only on the first
   `<s>` token. For question answering, it is advised to put *global attention* on all tokens of the question.
-- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by executing
-  `model.gradient_checkpointing_enable()`.
+- To fine-tune LED on all 16384, *gradient checkpointing* can be enabled in case training leads to out-of-memory (OOM)
+  errors. This can be done by executing `model.gradient_checkpointing_enable()`.
+  Moreover, the `use_cache=False`
+  flag can be used to disable the caching mechanism to save memory.
 - A notebook showing how to evaluate LED, can be accessed [here](https://colab.research.google.com/drive/12INTTR6n64TzS4RrXZxMSXfrOd9Xzamo?usp=sharing).
 - A notebook showing how to fine-tune LED, can be accessed [here](https://colab.research.google.com/drive/12LjJazBl7Gam0XBPy_y0CTOJZeZ34c2v?usp=sharing).

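The documentation changes above translate into a short setup sketch. This is only an illustration of the documented tips, not part of the commit itself; the checkpoint name, input text, and sequence length below are assumptions.

```python
import torch
from transformers import LEDForConditionalGeneration, LEDTokenizer

# Illustrative checkpoint; any LED checkpoint works the same way.
tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")

# Enable gradient checkpointing and disable caching to reduce memory
# pressure when fine-tuning on the full 16384-token input length.
model.gradient_checkpointing_enable()
model.config.use_cache = False

inputs = tokenizer(
    "a very long document ...", max_length=16384, truncation=True, return_tensors="pt"
)

# For summarization, put global attention only on the first (<s>) token.
global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1
inputs["global_attention_mask"] = global_attention_mask
```

When fine-tuning through the `Trainer`, the same effect can also be obtained by passing `gradient_checkpointing=True` to the training arguments, the alternative discussed in the commit message above.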
2 changes: 2 additions & 0 deletions src/transformers/models/led/modeling_led.py
@@ -2441,6 +2441,7 @@ def prepare_inputs_for_generation(
         decoder_input_ids,
         past=None,
         attention_mask=None,
+        global_attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
         cross_attn_head_mask=None,
@@ -2458,6 +2459,7 @@ def prepare_inputs_for_generation(
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"global_attention_mask": global_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
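The two added lines above are what let a `global_attention_mask` passed to `generate()` reach the model during generation instead of being dropped by `prepare_inputs_for_generation`. A minimal sketch of the resulting call follows; the checkpoint name and input text are assumptions for illustration.

```python
import torch
from transformers import LEDForConditionalGeneration, LEDTokenizer

tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")

inputs = tokenizer("a very long document ...", return_tensors="pt")

# Global attention on the first token only, as recommended for summarization.
global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1

# With this fix, the global_attention_mask keyword is forwarded to every
# forward pass made during generation.
summary_ids = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    global_attention_mask=global_attention_mask,
    max_length=256,
)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
```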
