Add BartModel #2745
Changes from 1 commit
@@ -27,7 +27,7 @@
 from .configuration_bart import BartConfig
 from .file_utils import add_start_docstrings
 from .modeling_utils import PreTrainedModel
+from .utils_encoder_decoder import prepare_encoder_decoder_model_kwargs

 logger = logging.getLogger(__name__)
@@ -94,15 +94,12 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.shared = value

-    def forward(self, input_ids: torch.LongTensor = None, return_for_head=False, **kwargs):
-        input_ids = input_ids if input_ids is not None else kwargs["encoder_input_ids"]  # TODO(SS): decide on API
-        if input_ids.dim() == 1:
-            input_ids = input_ids.unsqueeze(0)
-        if input_ids.size(-1) > min(self.max_positions()):
-            raise ValueError(
-                "input_ids exceeds maximum length: {} > {}".format(input_ids.size(-1), self.max_positions())
-            )
-        encoder_out = self.encoder(input_ids)
+    # def forward(self, input_ids: torch.LongTensor = None, return_for_head=False, **kwargs):
+    def forward(self, return_for_head=False, **kwargs):
+        kwargs_encoder, kwargs_decoder = prepare_encoder_decoder_model_kwargs(**kwargs)
+        # TODO(SS): only call encoder if we need to
+        encoder_out = self.encoder(**kwargs_encoder)
+        input_ids = kwargs_encoder.pop('input_ids')
         prev_output_tokens = self.shift_tokens_left(input_ids, self.config.pad_token_id)
         dec_features, dec_hidden, dec_attn = self.decoder(prev_output_tokens, encoder_out=encoder_out,)
Review comment: There should be more documentation here. So we are feeding the same input to both the encoder and the decoder, with the decoder input shifted by one token to the left. I feel like this logic is very specific to the pretraining of BART, and I'm wondering whether we should have it incorporated in the forward loop or rather as an external pre-processing step during the training loop (like we do for other pretraining logic, e.g. preparing inputs for MLM models). What do you think?

Review comment: Cf. my comment on the forward method API: I would maybe recommend moving this logic from the inner model to the specific derived model.
         if return_for_head:  # split encoder and decoder outputs nicely
sshleifer marked this conversation as resolved.
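For readers skimming the diff: the body of `shift_tokens_left` is not shown in this hunk, so the following is only a minimal sketch of what a one-token left shift with pad filling could look like, assuming `input_ids` has shape `(batch, seq_len)`; the PR's actual helper may differ.

```python
import torch

def shift_tokens_left(input_ids: torch.LongTensor, pad_token_id: int) -> torch.LongTensor:
    """Hypothetical sketch: position i of the output holds token i + 1 of the input,
    and the freed final position is filled with the pad token."""
    shifted = input_ids.new_full(input_ids.shape, pad_token_id)
    shifted[:, :-1] = input_ids[:, 1:]
    return shifted

# e.g. [[BOS, A, B, EOS]] -> [[A, B, EOS, PAD]]
```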
@@ -176,20 +173,16 @@ class BartForSequenceClassification(PretrainedBartModel):
     def __init__(self, config: BartConfig, **kwargs):
         super().__init__(config, **kwargs)
         self.model = BartModel(config)
-        self.classification_head = BARTClassificationHead(
+        self.classification_head = BartClassificationHead(
             config.d_model, config.d_model, config.num_labels, config.classif_dropout,
         )
         self.loss_fn = nn.CrossEntropyLoss()

-    def forward(self, input_ids, *args, **kwargs):
-
-        if input_ids.ndim == 1:
-            input_ids = input_ids.unsqueeze(0)
-        kwargs["return_for_head"] = True
-        decoder_outputs, encoder_outputs = self.model(input_ids, *args, **kwargs)
+    def forward(self, **kwargs):
+        labels = kwargs.pop("labels", None)
+        decoder_outputs, encoder_outputs = self.model(return_for_head=True, **kwargs)
         x = decoder_outputs[0]  # last hidden state
-
+        input_ids = _get_input_ids_from_kwargs(**kwargs)
         eos_mask = input_ids.eq(self.eos_token)
         if len(torch.unique(eos_mask.sum(1))) > 1:
             raise ValueError("All examples must have the same number of <eos> tokens.")
Review comment: Up to now we have avoided mixing model and tokenization logic, in particular to leave the user more freedom. We should discuss whether we want to change this philosophy here. cc @julien-c @LysandreJik

Review comment: I'm still a firm believer that model logic and tokenization logic should be separate. We used to have a similar warning/error in RoBERTa but removed it because it […]

Review comment: I think the best way is to use […]

Review comment: It's easier in Roberta to dissociate because we do […]

Review comment: Yes, sounds reasonable.
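The part of `forward` that actually consumes `eos_mask` is not visible in this hunk. As a rough sketch of how the `<eos>` hidden states could be pooled for classification, assuming `x` has shape `(batch, seq_len, d_model)` and every example has the same number of `<eos>` tokens (exactly what the `ValueError` above enforces); the PR's actual code may differ.

```python
import torch

def pool_eos_hidden_states(x: torch.Tensor, eos_mask: torch.Tensor) -> torch.Tensor:
    """Keep the decoder hidden state at each example's final <eos> position.

    x:        (batch, seq_len, d_model) decoder hidden states
    eos_mask: (batch, seq_len) boolean mask of <eos> positions
    """
    # Flatten all <eos> positions, regroup per example, keep the last one per row.
    per_example = x[eos_mask, :].view(x.size(0), -1, x.size(-1))
    return per_example[:, -1, :]  # (batch, d_model), fed to the classification head
```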
@@ -204,6 +197,9 @@ def forward(self, input_ids, *args, **kwargs):
         return decoder_outputs + encoder_outputs


+def _get_input_ids_from_kwargs(**kwargs):
+    """Try to get input_ids and if that key is not present get encoder_input_ids."""
+    return kwargs.get('input_ids', kwargs.get('encoder_input_ids', None))


 # Encoder and Decoder
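For clarity, the fallback behavior of `_get_input_ids_from_kwargs` with toy values (illustrative only):

```python
_get_input_ids_from_kwargs(input_ids="ids")                                # -> "ids"
_get_input_ids_from_kwargs(encoder_input_ids="enc_ids")                    # -> "enc_ids"
_get_input_ids_from_kwargs(input_ids="ids", encoder_input_ids="enc_ids")   # -> "ids" (input_ids wins)
_get_input_ids_from_kwargs(attention_mask="mask")                          # -> None
```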
@@ -481,33 +477,6 @@ def max_positions(self):
         """Maximum input length supported by the encoder."""
         return min(self.max_source_positions, self.embed_positions.max_positions)

-    # Unused
-    def reorder_encoder_out(self, encoder_out, new_order):
-        """
-        Reorder encoder output according to *new_order*.
-
-        Args:
-            encoder_out: output from the ``forward()`` method
-            new_order (LongTensor): desired order
-
-        Returns:
-            *encoder_out* rearranged according to *new_order*
-        """
-        if encoder_out.encoder_out is not None:
-            encoder_out = encoder_out._replace(encoder_out=encoder_out.encoder_out.index_select(1, new_order))
-        if encoder_out.encoder_padding_mask is not None:
-            encoder_out = encoder_out._replace(
-                encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(0, new_order)
-            )
-        if encoder_out.encoder_embedding is not None:
-            encoder_out = encoder_out._replace(
-                encoder_embedding=encoder_out.encoder_embedding.index_select(0, new_order)
-            )
-        if encoder_out.encoder_states is not None:
-            for idx, state in enumerate(encoder_out.encoder_states):
-                encoder_out.encoder_states[idx] = state.index_select(1, new_order)
-        return encoder_out


 class BartDecoder(nn.Module):
     """
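For context on what the removed `reorder_encoder_out` was for: during beam search the hypotheses are re-sorted at every step, so any cached encoder state has to be permuted with the same `new_order`. A toy illustration of the underlying `index_select` pattern (hypothetical shapes, not the PR's API):

```python
import torch

# Encoder output cached as (seq_len, batch * beam, hidden); dim 1 indexes hypotheses.
encoder_out = torch.randn(7, 6, 16)            # e.g. batch=2, beam=3
new_order = torch.tensor([1, 0, 2, 4, 3, 5])   # new hypothesis ordering after a beam step

reordered = encoder_out.index_select(1, new_order)  # permute along the hypothesis dimension

# Masks stored as (batch * beam, seq_len) would instead use index_select(0, new_order),
# which matches the mix of dims 0 and 1 in the removed code above.
```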
@@ -644,7 +613,7 @@ def buffered_future_mask(self, tensor):
 # Helper Modules


-class BARTClassificationHead(nn.Module):
+class BartClassificationHead(nn.Module):
     """Head for sentence-level classification tasks."""

     # This can trivially be shared with RobertaClassificationHead
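The head's definition itself is not part of this hunk. Judging from the constructor call earlier in the diff, `BartClassificationHead(input_dim, inner_dim, num_classes, dropout)`, and from the `RobertaClassificationHead` it is compared to, a plausible sketch is the usual dense → tanh → dropout → projection stack (an assumption, not necessarily the PR's exact code):

```python
import torch
from torch import nn

class BartClassificationHead(nn.Module):
    """Head for sentence-level classification tasks (sketch)."""

    def __init__(self, input_dim: int, inner_dim: int, num_classes: int, pooler_dropout: float):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, d_model) pooled <eos> representation
        x = self.dropout(x)
        x = torch.tanh(self.dense(x))
        x = self.dropout(x)
        return self.out_proj(x)  # (batch, num_classes)
```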
@@ -920,16 +889,6 @@ def _append_prev_key_padding_mask(
             new_key_padding_mask = prev_key_padding_mask
         return new_key_padding_mask

-    def reorder_incremental_state(self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order):
-        """Reorder buffered internal state (for incremental generation)."""
-        # TODO(SS): Where is this used?
-        input_buffer = self._get_input_buffer(incremental_state)
-        if input_buffer is not None:
-            for k in input_buffer.keys():
-                if input_buffer[k] is not None:
-                    input_buffer[k] = input_buffer[k].index_select(0, new_order)
-            self._set_input_buffer(incremental_state, input_buffer)
-
     def _get_input_buffer(
         self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
     ) -> Dict[str, Optional[Tensor]]:
@@ -0,0 +1,51 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Classes to support Encoder-Decoder architectures """
+
+
+def prepare_encoder_decoder_model_kwargs(**kwargs):
+    """ Prepare the encoder and decoder's keyword arguments.
+
+    Keyword arguments come in 3 flavors:
+    - encoder-specific (prefixed by `encoder_`)
+    - decoder-specific (prefixed by `decoder_`)
+    - those that apply to the model as whole.
+
+    We let the specific kwargs override the common ones in case of
+    conflict.
+    """
+    kwargs_common = {
+        argument: value
+        for argument, value in kwargs.items()
+        if not argument.startswith("encoder_") and not argument.startswith("decoder_")
+    }
+    decoder_kwargs = kwargs_common.copy()
+    encoder_kwargs = kwargs_common.copy()
+    encoder_kwargs.update(
+        {
+            argument[len("encoder_") :]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("encoder_")
+        }
+    )
+    decoder_kwargs.update(
+        {
+            argument[len("decoder_") :]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("decoder_")
+        }
+    )
+    decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None)
+    return encoder_kwargs, decoder_kwargs
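A quick illustration of how the splitting behaves, with toy placeholder values (run against the function above):

```python
kwargs = {
    "input_ids": "common_ids",             # shared by encoder and decoder
    "attention_mask": "common_mask",       # shared, but overridden for the encoder below
    "encoder_attention_mask": "enc_mask",  # encoder-specific, wins over the common mask
    "decoder_input_ids": "dec_ids",        # decoder-specific
}
encoder_kwargs, decoder_kwargs = prepare_encoder_decoder_model_kwargs(**kwargs)

assert encoder_kwargs == {"input_ids": "common_ids", "attention_mask": "enc_mask"}
assert decoder_kwargs == {
    "input_ids": "dec_ids",
    "attention_mask": "common_mask",
    "encoder_attention_mask": "enc_mask",
}
```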
Review comment: Won't this name mismatch cause the value saved by save_pretrained() not to be loaded into the config by the from_pretrained() method?

Review comment: I have no clue what problem you are trying to describe. Please file an issue with a pasteable code snippet that has a different output than you expected.

Review comment: OK, filed #7591.