Commit bc0d26d

[All Seq2Seq model + CLM models that can be used with EncoderDecoder] Add cross-attention weights to outputs (#8071)
* Output cross-attention with decoder attention output
* Update src/transformers/modeling_bert.py
* add cross-attention for t5 and bart as well
* fix tests
* correct typo in docs
* add sylvains and sams comments
* correct typo

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent 30f2507 commit bc0d26d

16 files changed: +653 -85 lines
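
The user-visible effect of this commit is that seq2seq and decoder-style outputs carry a cross_attentions field next to the existing self-attention weights. A minimal usage sketch with a BART base model (the checkpoint name and the exact tuple/shape layout are illustrative assumptions, not taken from the commit):

    from transformers import BartModel, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = BartModel.from_pretrained("facebook/bart-base")

    inputs = tokenizer("Hello world", return_tensors="pt")
    outputs = model(**inputs, output_attentions=True, return_dict=True)

    # One tensor per decoder layer; each should be
    # (batch_size, num_heads, target_length, source_length).
    print(len(outputs.cross_attentions), outputs.cross_attentions[0].shape)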

docs/source/main_classes/output.rst

Lines changed: 36 additions & 0 deletions
@@ -65,12 +65,34 @@ BaseModelOutputWithPooling
     :members:
 
 
+BaseModelOutputWithCrossAttentions
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithCrossAttentions
+    :members:
+
+
+BaseModelOutputWithPoolingAndCrossAttentions
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions
+    :members:
+
+
 BaseModelOutputWithPast
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPast
     :members:
 
+
+BaseModelOutputWithPastAndCrossAttentions
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions
+    :members:
+
+
 Seq2SeqModelOutput
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -85,6 +107,20 @@ CausalLMOutput
     :members:
 
 
+CausalLMOutputWithCrossAttentions
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
+    :members:
+
+
+CausalLMOutputWithPastAndCrossAttentions
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.modeling_outputs.CausalLMOutputWithPastAndCrossAttentions
+    :members:
+
+
 CausalLMOutputWithPast
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
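
The classes registered above live in transformers.modeling_outputs and behave like any other ModelOutput: fields are reachable by attribute or by string key. A small sketch with made-up tensors (shapes chosen only for illustration):

    import torch
    from transformers.modeling_outputs import BaseModelOutputWithCrossAttentions

    hidden = torch.zeros(1, 4, 16)      # (batch, seq_len, hidden_size)
    cross = (torch.zeros(1, 2, 4, 6),)  # per layer: (batch, heads, tgt_len, src_len)

    out = BaseModelOutputWithCrossAttentions(last_hidden_state=hidden, cross_attentions=cross)
    print(out.cross_attentions[0].shape)     # attribute access
    print(out["cross_attentions"][0].shape)  # dict-style access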

src/transformers/modeling_bart.py

Lines changed: 21 additions & 7 deletions
@@ -35,7 +35,7 @@
 )
 from .modeling_outputs import (
     BaseModelOutput,
-    BaseModelOutputWithPast,
+    BaseModelOutputWithPastAndCrossAttentions,
     Seq2SeqLMOutput,
     Seq2SeqModelOutput,
     Seq2SeqQuestionAnsweringModelOutput,
@@ -451,11 +451,12 @@ def forward(
         assert self.encoder_attn.cache_key != self.self_attn.cache_key
         if self.normalize_before:
             x = self.encoder_attn_layer_norm(x)
-        x, _ = self.encoder_attn(
+        x, cross_attn_weights = self.encoder_attn(
             query=x,
             key=encoder_hidden_states,
             key_padding_mask=encoder_attn_mask,
             layer_state=layer_state,  # mutates layer state
+            output_attentions=output_attentions,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
@@ -477,7 +478,8 @@ def forward(
             x,
             self_attn_weights,
             layer_state,
-        )  # just self_attn weights for now, following t5, layer_state = cache for decoding
+            cross_attn_weights,
+        )  # layer_state = cache for decoding
 
 
 class BartDecoder(nn.Module):
@@ -590,6 +592,7 @@ def forward(
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if output_attentions else None
         next_decoder_cache: List[Dict] = []
         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -601,7 +604,7 @@ def forward(
 
             layer_state = past_key_values[idx] if past_key_values is not None else None
 
-            x, layer_self_attn, layer_past = decoder_layer(
+            x, layer_self_attn, layer_past, layer_cross_attn = decoder_layer(
                 x,
                 encoder_hidden_states,
                 encoder_attn_mask=encoder_padding_mask,
@@ -616,6 +619,7 @@ def forward(
 
             if output_attentions:
                 all_self_attns += (layer_self_attn,)
+                all_cross_attentions += (layer_cross_attn,)
 
         if self.layer_norm:  # if config.add_final_layer_norm (mBART)
             x = self.layer_norm(x)
@@ -628,9 +632,15 @@ def forward(
 
         next_cache = next_decoder_cache if use_cache else None
         if not return_dict:
-            return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
-        return BaseModelOutputWithPast(
-            last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns
+            return tuple(
+                v for v in [x, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=x,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
         )
 
 
@@ -934,6 +944,7 @@ def forward(
             past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
@@ -1078,6 +1089,7 @@ def forward(
             past_key_values=outputs.past_key_values,
             decoder_hidden_states=outputs.decoder_hidden_states,
             decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
             encoder_last_hidden_state=outputs.encoder_last_hidden_state,
             encoder_hidden_states=outputs.encoder_hidden_states,
             encoder_attentions=outputs.encoder_attentions,
@@ -1207,6 +1219,7 @@ def forward(
             past_key_values=outputs.past_key_values,
             decoder_hidden_states=outputs.decoder_hidden_states,
             decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
             encoder_last_hidden_state=outputs.encoder_last_hidden_state,
             encoder_hidden_states=outputs.encoder_hidden_states,
             encoder_attentions=outputs.encoder_attentions,
@@ -1317,6 +1330,7 @@ def forward(
             past_key_values=outputs.past_key_values,
             decoder_hidden_states=outputs.decoder_hidden_states,
             decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
             encoder_last_hidden_state=outputs.encoder_last_hidden_state,
             encoder_hidden_states=outputs.encoder_hidden_states,
             encoder_attentions=outputs.encoder_attentions,
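
With BartDecoder now collecting layer_cross_attn per layer, the seq2seq heads expose decoder self-attention and cross-attention side by side. A hedged sketch with BartForConditionalGeneration (the checkpoint name is only an example):

    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

    batch = tokenizer("UN Chief Says There Is No Military Solution in Syria", return_tensors="pt")
    outputs = model(**batch, output_attentions=True, return_dict=True)

    # decoder_attentions: decoder tokens attending to decoder tokens
    # cross_attentions:   decoder tokens attending to the encoder sequence
    print(outputs.decoder_attentions[0].shape)
    print(outputs.cross_attentions[0].shape)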

src/transformers/modeling_bert.py

Lines changed: 24 additions & 12 deletions
@@ -37,9 +37,9 @@
     replace_return_docstrings,
 )
 from .modeling_outputs import (
-    BaseModelOutput,
-    BaseModelOutputWithPooling,
-    CausalLMOutput,
+    BaseModelOutputWithCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
     MaskedLMOutput,
     MultipleChoiceModelOutput,
     NextSentencePredictorOutput,
@@ -449,7 +449,8 @@ def forward(
         return_dict=False,
     ):
         all_hidden_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -483,15 +484,24 @@ def custom_forward(*inputs):
             )
             hidden_states = layer_outputs[0]
             if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
 
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
-        return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+            return tuple(
+                v
+                for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithCrossAttentions(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
         )
 
 
@@ -752,7 +762,7 @@ class PreTrainedModel
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="bert-base-uncased",
-        output_type=BaseModelOutputWithPooling,
+        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
@@ -843,11 +853,12 @@ def forward(
         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
 
-        return BaseModelOutputWithPooling(
+        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
             pooler_output=pooled_output,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
         )
 
 
@@ -984,7 +995,7 @@ def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids=None,
@@ -1063,11 +1074,12 @@ def forward(
             output = (prediction_scores,) + outputs[2:]
             return ((lm_loss,) + output) if lm_loss is not None else output
 
-        return CausalLMOutput(
+        return CausalLMOutputWithCrossAttentions(
             loss=lm_loss,
             logits=prediction_scores,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
         )
 
     def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
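
Because BertLMHeadModel now returns CausalLMOutputWithCrossAttentions, a BERT decoder configured with is_decoder=True and add_cross_attention=True can surface its attention over encoder states. A sketch under those assumptions (checkpoint names are illustrative, and the decoder's cross-attention weights are freshly initialized rather than pretrained):

    from transformers import BertConfig, BertLMHeadModel, BertModel, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    encoder = BertModel.from_pretrained("bert-base-uncased")
    decoder_config = BertConfig.from_pretrained(
        "bert-base-uncased", is_decoder=True, add_cross_attention=True
    )
    decoder = BertLMHeadModel.from_pretrained("bert-base-uncased", config=decoder_config)

    enc_inputs = tokenizer("The encoder input.", return_tensors="pt")
    dec_inputs = tokenizer("The decoder input.", return_tensors="pt")

    encoder_hidden = encoder(**enc_inputs, return_dict=True).last_hidden_state
    outputs = decoder(
        input_ids=dec_inputs["input_ids"],
        attention_mask=dec_inputs["attention_mask"],
        encoder_hidden_states=encoder_hidden,
        encoder_attention_mask=enc_inputs["attention_mask"],
        output_attentions=True,
        return_dict=True,
    )

    # One cross-attention map per layer: (batch, heads, decoder_len, encoder_len).
    print(outputs.cross_attentions[0].shape)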

src/transformers/modeling_bert_generation.py

Lines changed: 7 additions & 5 deletions
@@ -28,7 +28,7 @@
     replace_return_docstrings,
 )
 from .modeling_bert import BertEncoder
-from .modeling_outputs import BaseModelOutput, CausalLMOutput
+from .modeling_outputs import BaseModelOutputWithCrossAttentions, CausalLMOutputWithCrossAttentions
 from .modeling_utils import PreTrainedModel
 from .utils import logging
 
@@ -297,7 +297,7 @@ class PreTrainedModel
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="google/bert_for_seq_generation_L-24_bbc_encoder",
-        output_type=BaseModelOutput,
+        output_type=BaseModelOutputWithCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
@@ -381,10 +381,11 @@ def forward(
         if not return_dict:
             return (sequence_output,) + encoder_outputs[1:]
 
-        return BaseModelOutput(
+        return BaseModelOutputWithCrossAttentions(
             last_hidden_state=sequence_output,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
         )
 
 
@@ -422,7 +423,7 @@ def get_output_embeddings(self):
         return self.lm_head.decoder
 
     @add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids=None,
@@ -499,11 +500,12 @@ def forward(
             output = (prediction_scores,) + outputs[1:]
             return ((lm_loss,) + output) if lm_loss is not None else output
 
-        return CausalLMOutput(
+        return CausalLMOutputWithCrossAttentions(
             loss=lm_loss,
             logits=prediction_scores,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
         )
 
     def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):

src/transformers/modeling_electra.py

Lines changed: 17 additions & 7 deletions
@@ -34,7 +34,7 @@
     replace_return_docstrings,
 )
 from .modeling_outputs import (
-    BaseModelOutput,
+    BaseModelOutputWithCrossAttentions,
     MaskedLMOutput,
     MultipleChoiceModelOutput,
     QuestionAnsweringModelOutput,
@@ -445,7 +445,8 @@ def forward(
         return_dict=False,
     ):
         all_hidden_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -479,15 +480,24 @@ def custom_forward(*inputs):
             )
             hidden_states = layer_outputs[0]
             if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
 
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
-        return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+            return tuple(
+                v
+                for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithCrossAttentions(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
         )
 
 
@@ -697,7 +707,7 @@ class PreTrainedModel
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="google/electra-small-discriminator",
-        output_type=BaseModelOutput,
+        output_type=BaseModelOutputWithCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(

src/transformers/modeling_encoder_decoder.py

Lines changed: 1 addition & 0 deletions
@@ -426,6 +426,7 @@ def forward(
             past_key_values=None,  # TODO(PVP) - need to implement cache for BERT, etc... before this works
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
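
End to end, EncoderDecoderModel now forwards the decoder's cross_attentions into its Seq2SeqLMOutput. A rough sketch; it assumes from_encoder_decoder_pretrained enables is_decoder/add_cross_attention on the decoder config (set them manually on the decoder config otherwise), and the checkpoint pairing is only an example:

    from transformers import BertTokenizer, EncoderDecoderModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-base-uncased", "bert-base-uncased"
    )

    inputs = tokenizer("Cross-attention weights are now returned.", return_tensors="pt")
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        decoder_input_ids=inputs["input_ids"],
        output_attentions=True,
        return_dict=True,
    )

    # Attention from each decoder layer onto the encoder output.
    print(len(outputs.cross_attentions), outputs.cross_attentions[0].shape)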
