Add BartModel #2745

Merged (168 commits, Feb 20, 2020)
Changes from 1 commit
b5c20db
3 new files
sshleifer Jan 23, 2020
8420dc5
Lots of fairseq copy paste
sshleifer Jan 24, 2020
22ccda1
typo idiocy
sshleifer Jan 24, 2020
03d2cf3
Copy paste code that we know we wont use
sshleifer Jan 24, 2020
d99326e
before consider Roberta way
sshleifer Jan 24, 2020
43c7e21
add tokenization: identical to Roberta
sshleifer Jan 25, 2020
24fb639
register in configuration auto
sshleifer Jan 25, 2020
61409b4
mid consolidation of fairseq heirarchy
sshleifer Jan 26, 2020
0b79f39
Forward works, but shapes are wrong
sshleifer Jan 27, 2020
dcf2b88
copy pasted tests
sshleifer Jan 27, 2020
92e487f
matching state dict after upgrade
sshleifer Jan 30, 2020
e0c54ed
Merge remote-tracking branch 'upstream/master' into bart
sshleifer Jan 31, 2020
3871a7a
rm typo
sshleifer Jan 31, 2020
dbe83c9
del maybe layernorm
sshleifer Feb 2, 2020
2373e8a
Delete more maybe_layer_norm
sshleifer Feb 2, 2020
0dda528
Moved code round
sshleifer Feb 2, 2020
69327e4
fixed base tests, some notimpl for attention module
sshleifer Feb 3, 2020
d630887
config cleanup
sshleifer Feb 3, 2020
3cbc6ca
Some test cleanup
sshleifer Feb 4, 2020
51ab277
Merge branch 'master' into bart
sshleifer Feb 4, 2020
9e694a7
fixed attn weights shape failure with big copy paste
sshleifer Feb 4, 2020
f355e36
passing hidden_states shape test
sshleifer Feb 4, 2020
3971d97
initializer_factor, passing more tests
sshleifer Feb 5, 2020
f42997f
Merge remote-tracking branch 'upstream/master' into bart
sshleifer Feb 5, 2020
56c4744
utests pass
sshleifer Feb 5, 2020
26656a0
del unused file
sshleifer Feb 5, 2020
4d77a7c
whitespace
sshleifer Feb 5, 2020
0ce724a
Merge remote-tracking branch 'upstream/master' into bart
sshleifer Feb 5, 2020
38e057f
trailing comma
sshleifer Feb 5, 2020
a48f89e
make style, quality
sshleifer Feb 5, 2020
2ad6e7b
remove fairseq dep
sshleifer Feb 5, 2020
a772509
Undo error type change
sshleifer Feb 5, 2020
831fd14
whitespace
sshleifer Feb 5, 2020
be62f89
fix fstring
sshleifer Feb 5, 2020
6726d33
black
sshleifer Feb 5, 2020
c0e9510
type hinting only in comments for py35
sshleifer Feb 5, 2020
28bcf61
more style
sshleifer Feb 5, 2020
1f0b885
isort
sshleifer Feb 5, 2020
6aea2b8
fix NameError
sshleifer Feb 5, 2020
4e7279c
del methods
sshleifer Feb 5, 2020
28345b4
del methods
sshleifer Feb 5, 2020
effa170
F.gelu
sshleifer Feb 5, 2020
8c7df3a
small
sshleifer Feb 5, 2020
cee5051
style
sshleifer Feb 5, 2020
586098d
Working conversion script
sshleifer Feb 6, 2020
67b02c6
cleaning
sshleifer Feb 6, 2020
3811209
test init more directly
sshleifer Feb 6, 2020
28c977b
more variance checks
sshleifer Feb 6, 2020
b79509d
hardcoding expected results
sshleifer Feb 6, 2020
5bc3081
undo stupid change
sshleifer Feb 6, 2020
df6edc3
idiot
sshleifer Feb 6, 2020
5eaade8
delete torch version
sshleifer Feb 6, 2020
edc492e
cleanup, passing
sshleifer Feb 6, 2020
7c090b0
cleanup, passing
sshleifer Feb 6, 2020
1d6cde6
passing
sshleifer Feb 6, 2020
1c06538
Style
sshleifer Feb 6, 2020
73cad04
more deletion
sshleifer Feb 6, 2020
a68c20e
cleanup style, passing
sshleifer Feb 6, 2020
5d1bc99
resize_embeddings test passing
sshleifer Feb 6, 2020
c23a07b
AutoTokenizer support
sshleifer Feb 7, 2020
67ef42f
one file
sshleifer Feb 7, 2020
7a4a6e2
Fix class ordering
sshleifer Feb 7, 2020
f80ce45
conversion broken
sshleifer Feb 7, 2020
42e061b
some old changes
sshleifer Feb 7, 2020
28b1f80
conversion scripts work
sshleifer Feb 7, 2020
4e008e6
fix s3 linking
sshleifer Feb 7, 2020
60bd737
no cnn model
sshleifer Feb 7, 2020
a4edf2e
no cnn
sshleifer Feb 7, 2020
e1d106d
One kwarg for encoder_decoder_attention
sshleifer Feb 8, 2020
a9b979f
cleanup
sshleifer Feb 8, 2020
4b97345
BROKEN
sshleifer Feb 8, 2020
ed642cc
half fixed
sshleifer Feb 8, 2020
87ddeae
Merge remote-tracking branch 'upstream/master' into bart
sshleifer Feb 8, 2020
4628b7d
hoist split_kwargs
sshleifer Feb 9, 2020
a653c78
split out testing BartForSequenceClassification
sshleifer Feb 9, 2020
ab594b4
cleanup
sshleifer Feb 9, 2020
73f49a6
lmhead test passing
sshleifer Feb 9, 2020
f7d88db
calc loss in SeqClassification model
sshleifer Feb 9, 2020
8f04dd5
Fix newlines
sshleifer Feb 9, 2020
459aeaf
ci
sshleifer Feb 9, 2020
9ecee5b
comment public API
sshleifer Feb 10, 2020
aadf762
comments
sshleifer Feb 10, 2020
bac8348
reverted API changes
sshleifer Feb 10, 2020
66310db
style
sshleifer Feb 10, 2020
3f03344
isort
sshleifer Feb 10, 2020
808bbd5
Revert "isort"
sshleifer Feb 10, 2020
92b5f6e
some cleanup
sshleifer Feb 10, 2020
21ac214
Merge remote-tracking branch 'upstream/master' into bart
sshleifer Feb 10, 2020
2196cc2
Sty works here
sshleifer Feb 10, 2020
960af22
cleanup
sshleifer Feb 10, 2020
a812adc
fix slow tests
sshleifer Feb 10, 2020
a8a7839
long
sshleifer Feb 10, 2020
49f60d7
cleanup
sshleifer Feb 10, 2020
376a358
cleanup
sshleifer Feb 10, 2020
8ecdd0d
rename BartForMaskedLM
sshleifer Feb 10, 2020
537af62
use masked loss
sshleifer Feb 10, 2020
02b56df
Merge remote-tracking branch 'upstream/master' into bart
sshleifer Feb 10, 2020
765c98a
Merge remote-tracking branch 'upstream/master' into bart
sshleifer Feb 11, 2020
4e1a5e0
Factor in RobertaTokenizer changes
sshleifer Feb 12, 2020
4339102
delete reorder_ functions
sshleifer Feb 12, 2020
3ce6c1e
pop ignore keys
sshleifer Feb 12, 2020
fd3d991
Fix S3 URLs
sshleifer Feb 12, 2020
b22b368
Conform to t5 API
sshleifer Feb 12, 2020
e5c3485
no head passing
sshleifer Feb 12, 2020
e2827b1
mnli passing
sshleifer Feb 13, 2020
4d49735
Generate works, other stuff broken
sshleifer Feb 13, 2020
ac1657b
Only test inputs embeds fails
sshleifer Feb 13, 2020
2a1260a
caching might work
sshleifer Feb 13, 2020
afbfdeb
Merge remote-tracking branch 'upstream/master' into bart
sshleifer Feb 13, 2020
82877e7
Merge remote-tracking branch 'upstream/master' into bart
sshleifer Feb 13, 2020
8252075
Naming changes, tests pass besides embeds
sshleifer Feb 13, 2020
6bacd55
Dont support inputs embeds
sshleifer Feb 13, 2020
71c345f
New signatures, mnli passing
sshleifer Feb 13, 2020
67a4cee
MNLI PASSING, still two masks
sshleifer Feb 13, 2020
6fd50b3
Padding test passes
sshleifer Feb 13, 2020
db3bc84
One combined attn mask
sshleifer Feb 14, 2020
264f6d1
temp revert
sshleifer Feb 14, 2020
ba25b7a
Passing shape test
sshleifer Feb 14, 2020
8f1e8b4
Style
sshleifer Feb 14, 2020
6124967
passing
sshleifer Feb 14, 2020
c01e719
cleanup
sshleifer Feb 14, 2020
5dfc207
test_shift_tokens_right
sshleifer Feb 14, 2020
dafdac8
Move public API to bottom of file
sshleifer Feb 16, 2020
40f7f79
cleanup return types
sshleifer Feb 16, 2020
e7ea674
Share create_position_ids_from_input_ids with roberta
sshleifer Feb 16, 2020
36e1adc
Initialize SequenceClassification correctly
sshleifer Feb 16, 2020
de2ced0
working. About to hoist inputs
sshleifer Feb 16, 2020
8b5bb52
tests pass with new API
sshleifer Feb 16, 2020
c2973d4
py35 compat: type hint in comment
sshleifer Feb 17, 2020
dbe0f4e
more require_torch
sshleifer Feb 17, 2020
a42ac9c
Fix merge conflict
sshleifer Feb 17, 2020
5faa0dd
Redo cached_states rename
sshleifer Feb 18, 2020
6a08f84
Make masks if user doesnt supply. Passing.
sshleifer Feb 18, 2020
c439e19
style
sshleifer Feb 18, 2020
6205ba6
Delete epically slow test
sshleifer Feb 18, 2020
cda9ced
style
sshleifer Feb 18, 2020
f3b4f21
start docs
sshleifer Feb 18, 2020
85c3b77
More docs
sshleifer Feb 18, 2020
5292ab3
style
sshleifer Feb 18, 2020
e2353c3
style
sshleifer Feb 18, 2020
16d2e2e
more docs
sshleifer Feb 18, 2020
2ede7ab
sty
sshleifer Feb 18, 2020
cb425f3
passing
sshleifer Feb 18, 2020
360db12
passing
sshleifer Feb 18, 2020
3c6f62d
style
sshleifer Feb 18, 2020
2d69571
passing
sshleifer Feb 18, 2020
de98500
some attention cleanup
sshleifer Feb 18, 2020
35d421b
docstrings
sshleifer Feb 18, 2020
9e66bbc
More coverage
sshleifer Feb 18, 2020
d546db4
More test coverage (test_chg branch)
sshleifer Feb 18, 2020
3a37397
kill dead
sshleifer Feb 18, 2020
9b97322
Failing tokenizer test
sshleifer Feb 19, 2020
0f2819c
some docs
sshleifer Feb 19, 2020
12b83b9
Docs work, but are innacurate
sshleifer Feb 19, 2020
12becba
newlne
sshleifer Feb 20, 2020
77578ac
merge upstream
sshleifer Feb 20, 2020
5990cfe
Style
sshleifer Feb 20, 2020
6cff072
Fix decoder_attention_mask test
sshleifer Feb 20, 2020
e032d06
Adopt roberta behavior
sshleifer Feb 20, 2020
5592784
lower tolerance tests
sshleifer Feb 20, 2020
0e0b9b1
Tests passing
sshleifer Feb 20, 2020
4a4723e
Delete input_prep test, its trivial
sshleifer Feb 20, 2020
4a212a2
test passing in mask
sshleifer Feb 20, 2020
086b17a
more coverage
sshleifer Feb 20, 2020
feaf207
Merge remote-tracking branch 'upstream/master' into bart
sshleifer Feb 20, 2020
2c8225a
Docs accurate
sshleifer Feb 20, 2020
300df06
revert generation change
sshleifer Feb 20, 2020
6db143e
improved docs
sshleifer Feb 20, 2020

Viewing commit: delete reorder_ functions
sshleifer committed Feb 12, 2020
commit 43391026541888bbf766448e2a631639c2558aa1
2 changes: 0 additions & 2 deletions src/transformers/configuration_bart.py
@@ -54,7 +54,6 @@ def __init__(
max_position_embeddings=1024,
init_std=0.02,
classifier_dropout=0.0,
num_labels=3,
**common_kwargs
):
super().__init__(**common_kwargs)
@@ -81,7 +80,6 @@ def __init__(

# Classifier stuff
self.classif_dropout = classifier_dropout

Contributor:
Won't this name mismatch cause the value saved by save_pretrained() not to be loaded back into the config by the from_pretrained() method?

Contributor Author:
I have no clue what problem you are trying to describe. Please file an issue with a pasteable code snippet whose output differs from what you expected.

Contributor:
OK, filed #7591.

(A minimal round-trip sketch of this concern follows the diff below.)

self.num_labels = num_labels

@property
def num_attention_heads(self):
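
The review thread above comes down to a serialization round trip: BartConfig stores the classifier_dropout argument under the attribute name classif_dropout, so the key written out by save_pretrained() no longer matches the __init__ keyword. The toy sketch below is plain Python, not the actual PretrainedConfig implementation (the real from_pretrained() may restore unknown keys via setattr and thus avoid the problem, which is exactly what issue #7591 is meant to settle); it only illustrates how such a mismatch can silently drop a saved value:

import json
import os
import tempfile

class ToyConfig:
    # Toy stand-in: the attribute name (classif_dropout) differs from the
    # __init__ keyword (classifier_dropout), mirroring the mismatch flagged above.
    def __init__(self, classifier_dropout=0.0, **unused_kwargs):
        self.classif_dropout = classifier_dropout  # stored under a different name

    def save_pretrained(self, directory):
        # Serializes the attributes, so the key on disk is "classif_dropout".
        with open(os.path.join(directory, "config.json"), "w") as f:
            json.dump(self.__dict__, f)

    @classmethod
    def from_pretrained(cls, directory):
        # Rebuilds the config through __init__ keywords only.
        with open(os.path.join(directory, "config.json")) as f:
            return cls(**json.load(f))

tmp = tempfile.mkdtemp()
ToyConfig(classifier_dropout=0.3).save_pretrained(tmp)
print(ToyConfig.from_pretrained(tmp).classif_dropout)  # prints 0.0; the saved 0.3 was dropped
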
71 changes: 15 additions & 56 deletions src/transformers/modeling_bart.py
@@ -27,7 +27,7 @@
from .configuration_bart import BartConfig
from .file_utils import add_start_docstrings
from .modeling_utils import PreTrainedModel

from .utils_encoder_decoder import prepare_encoder_decoder_model_kwargs

logger = logging.getLogger(__name__)

@@ -94,15 +94,12 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.shared = value

def forward(self, input_ids: torch.LongTensor = None, return_for_head=False, **kwargs):
input_ids = input_ids if input_ids is not None else kwargs["encoder_input_ids"] # TODO(SS): decide on API
if input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(0)
if input_ids.size(-1) > min(self.max_positions()):
raise ValueError(
"input_ids exceeds maximum length: {} > {}".format(input_ids.size(-1), self.max_positions())
)
encoder_out = self.encoder(input_ids)
#def forward(self, input_ids: torch.LongTensor = None, return_for_head=False, **kwargs):
def forward(self,return_for_head=False, **kwargs):
kwargs_encoder, kwargs_decoder = prepare_encoder_decoder_model_kwargs(**kwargs)
# TODO(SS): only call encoder if we need to
encoder_out = self.encoder(**kwargs_encoder)
input_ids = kwargs_encoder.pop('input_ids')
prev_output_tokens = self.shift_tokens_left(input_ids, self.config.pad_token_id)
dec_features, dec_hidden, dec_attn = self.decoder(prev_output_tokens, encoder_out=encoder_out,)

Member:
There should be more documentation here. So we are feeding the same input to both the encoder and the decoder, with the decoder input shifted by one token to the left.

I feel this logic is very specific to BART pretraining, and I'm wondering whether it should be incorporated in the forward pass or rather handled as external pre-processing in the training loop (as we do for other pretraining logic, e.g. preparing inputs for MLM models). What do you think?

Member:
Cf. my comment on the forward-method API: I might recommend moving this logic from the inner model to the specific derived model BertForSequenceClassification().

if return_for_head: # split encoder and decoder outputs nicely
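
For readers following the comment above: inside forward(), the decoder input is built by offsetting the encoder input by one position (teacher forcing) via shift_tokens_left, whose body is not shown in this diff. The helper below is therefore only an assumed, fairseq-style sketch of the idea (it rotates each sequence's final non-pad token to the front), written as the kind of standalone pre-processing step the reviewer suggests pulling out of the model:

import torch

def make_decoder_inputs(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Hypothetical helper: move every token one position over and place each
    # sequence's final non-pad token (usually </s>) in front, so the decoder
    # predicts token t from the tokens before t.
    prev_output_tokens = input_ids.clone()
    index_of_last_token = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_last_token).squeeze(-1)
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens

# Usage (hypothetical): computed outside the model during the training loop,
# then passed in as decoder_input_ids instead of being derived inside forward().
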
@@ -176,20 +173,16 @@ class BartForSequenceClassification(PretrainedBartModel):
def __init__(self, config: BartConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = BartModel(config)
self.classification_head = BARTClassificationHead(
self.classification_head = BartClassificationHead(
config.d_model, config.d_model, config.num_labels, config.classif_dropout,
)
self.loss_fn = nn.CrossEntropyLoss()

def forward(self, input_ids, *args, **kwargs):

if input_ids.ndim == 1:
input_ids = input_ids.unsqueeze(0)
kwargs["return_for_head"] = True
def forward(self, **kwargs):
labels = kwargs.pop("labels", None)
decoder_outputs, encoder_outputs = self.model(input_ids, *args, **kwargs)
decoder_outputs, encoder_outputs = self.model(return_for_head=True, **kwargs)
x = decoder_outputs[0] # last hidden state

input_ids = _get_input_ids_from_kwargs(**kwargs)
eos_mask = input_ids.eq(self.eos_token)
if len(torch.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")

Member:
Up to now we have avoided mixing model and tokenization logic, in particular to leave the user more freedom; we should discuss whether we want to change this philosophy here. cc @julien-c @LysandreJik

Member:
I'm still a firm believer that model logic and tokenization logic should be separate. We used to have a similar warning/error in RoBERTa but removed it because it

  1. made it impossible to trace the method, and therefore unusable on TPU;
  2. dissociated it from the tokenizer-logic vs. model-logic separation used in all other models.

Contributor Author:
I think the best way is to use SequenceSummary(cls_index). We might still need the eos token to build the cls_index tensor, if I'm understanding it correctly.

Contributor Author:
It's easier to dissociate in RoBERTa because we do x = features[:, 0, :]  # take <s> token (equiv. to [CLS]). We can't make a similar assumption about <eos> because of padding tokens.

Member:
Yes, sounds reasonable.
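
The diff is folded right after the eos_mask check, so as context for the thread above: the pooling that follows takes the decoder hidden state at the last <eos> position of each example and feeds it to the classification head, analogous to RoBERTa taking the first <s> token. The sketch below is a reconstruction for illustration, not the literal folded lines:

import torch

def pool_last_eos(x: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int) -> torch.Tensor:
    # x: (batch, seq_len, d_model) decoder output; returns (batch, d_model).
    eos_mask = input_ids.eq(eos_token_id)
    if len(torch.unique(eos_mask.sum(1))) > 1:
        raise ValueError("All examples must have the same number of <eos> tokens.")
    # Gather all <eos> positions, then keep the final one per example.
    return x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]

# sentence_repr = pool_last_eos(x, input_ids, eos_token_id)
# logits = classification_head(sentence_repr)  # hypothetical continuation
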

@@ -204,6 +197,9 @@ def forward(self, input_ids, *args, **kwargs):
return decoder_outputs + encoder_outputs


def _get_input_ids_from_kwargs(**kwargs):
"""Try to get input_ids and if that key is not present get encoder_input_ids."""
return kwargs.get('input_ids', kwargs.get('encoder_input_ids', None))
# Encoder and Decoder


@@ -481,33 +477,6 @@ def max_positions(self):
"""Maximum input length supported by the encoder."""
return min(self.max_source_positions, self.embed_positions.max_positions)

# Unused
def reorder_encoder_out(self, encoder_out, new_order):
"""
Reorder encoder output according to *new_order*.

Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order

Returns:
*encoder_out* rearranged according to *new_order*
"""
if encoder_out.encoder_out is not None:
encoder_out = encoder_out._replace(encoder_out=encoder_out.encoder_out.index_select(1, new_order))
if encoder_out.encoder_padding_mask is not None:
encoder_out = encoder_out._replace(
encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(0, new_order)
)
if encoder_out.encoder_embedding is not None:
encoder_out = encoder_out._replace(
encoder_embedding=encoder_out.encoder_embedding.index_select(0, new_order)
)
if encoder_out.encoder_states is not None:
for idx, state in enumerate(encoder_out.encoder_states):
encoder_out.encoder_states[idx] = state.index_select(1, new_order)
return encoder_out


class BartDecoder(nn.Module):
"""
Expand Down Expand Up @@ -644,7 +613,7 @@ def buffered_future_mask(self, tensor):
# Helper Modules


class BARTClassificationHead(nn.Module):
class BartClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""

# This can trivially be shared with RobertaClassificationHead
@@ -920,16 +889,6 @@ def _append_prev_key_padding_mask(
new_key_padding_mask = prev_key_padding_mask
return new_key_padding_mask

def reorder_incremental_state(self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order):
"""Reorder buffered internal state (for incremental generation)."""
# TODO(SS): Where is this used?
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
for k in input_buffer.keys():
if input_buffer[k] is not None:
input_buffer[k] = input_buffer[k].index_select(0, new_order)
self._set_input_buffer(incremental_state, input_buffer)

def _get_input_buffer(
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
) -> Dict[str, Optional[Tensor]]:
38 changes: 2 additions & 36 deletions src/transformers/modeling_encoder_decoder.py
@@ -236,42 +236,6 @@ def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):

return decoder_outputs + encoder_outputs

@staticmethod
def prepare_model_kwargs(**kwargs):
""" Prepare the encoder and decoder's keyword arguments.

Keyword arguments come in 3 flavors:
- encoder-specific (prefixed by `encoder_`)
- decoder-specific (prefixed by `decoder_`)
- those that apply to the model as whole.

We let the specific kwargs override the common ones in case of
conflict.
"""
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
}
decoder_kwargs = kwargs_common.copy()
encoder_kwargs = kwargs_common.copy()
encoder_kwargs.update(
{
argument[len("encoder_") :]: value
for argument, value in kwargs.items()
if argument.startswith("encoder_")
}
)
decoder_kwargs.update(
{
argument[len("decoder_") :]: value
for argument, value in kwargs.items()
if argument.startswith("decoder_")
}
)
decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None)
return encoder_kwargs, decoder_kwargs


class Model2Model(PreTrainedEncoderDecoder):
r"""
@@ -348,3 +312,5 @@ def from_pretrained(cls, *args, **kwargs):
kwargs["decoder_model"] = torch.nn.LSTM(kwargs.pop("decoder_config"))
model = super().from_pretrained(*args, **kwargs)
return model


51 changes: 51 additions & 0 deletions src/transformers/utils_encoder_decoder.py
@@ -0,0 +1,51 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Classes to support Encoder-Decoder architectures """


def prepare_encoder_decoder_model_kwargs(**kwargs):
""" Prepare the encoder and decoder's keyword arguments.

Keyword arguments come in 3 flavors:
- encoder-specific (prefixed by `encoder_`)
- decoder-specific (prefixed by `decoder_`)
- those that apply to the model as whole.

We let the specific kwargs override the common ones in case of
conflict.
"""
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
}
decoder_kwargs = kwargs_common.copy()
encoder_kwargs = kwargs_common.copy()
encoder_kwargs.update(
{
argument[len("encoder_") :]: value
for argument, value in kwargs.items()
if argument.startswith("encoder_")
}
)
decoder_kwargs.update(
{
argument[len("decoder_") :]: value
for argument, value in kwargs.items()
if argument.startswith("decoder_")
}
)
decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None)
return encoder_kwargs, decoder_kwargs
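
To make the splitting rule concrete, here is a small usage sketch of the helper above (the inputs are illustrative lists rather than tensors, and the import path simply mirrors the new file's location; the function may not be re-exported at the package root):

from transformers.utils_encoder_decoder import prepare_encoder_decoder_model_kwargs

encoder_kwargs, decoder_kwargs = prepare_encoder_decoder_model_kwargs(
    encoder_input_ids=[[0, 31414, 2]],  # encoder-specific: prefix is stripped
    decoder_input_ids=[[0, 31414, 2]],  # decoder-specific: prefix is stripped
    attention_mask=[[1, 1, 1]],         # common: copied to both sides
)
# encoder_kwargs == {"attention_mask": [[1, 1, 1]], "input_ids": [[0, 31414, 2]]}
# decoder_kwargs == {"attention_mask": [[1, 1, 1]], "input_ids": [[0, 31414, 2]],
#                    "encoder_attention_mask": [[1, 1, 1]]}
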
10 changes: 5 additions & 5 deletions tests/test_modeling_bart.py
@@ -116,11 +116,11 @@ def prepare_config_and_inputs_for_common(self):
return (
config,
{
"input_ids": input_ids,
#"input_ids": input_ids,
"token_type_ids": token_type_ids,
"attention_mask": input_mask,
"encoder_input_ids": input_ids,
"decoder_input_ids": input_ids, # TODO(SS): use prepare_model_kwargs llike T5
"decoder_input_ids": input_ids,
"decoder_lm_labels": decoder_lm_labels,
},
)
@@ -213,13 +213,13 @@ def test_forward(self):
max_position_embeddings=48,
)
model = BartForSequenceClassification(config)
outputs = model(input_ids)
outputs = model(input_ids=input_ids)
logits = outputs[0]
expected_shape = torch.Size((self.batch_size, config.num_labels))
self.assertEqual(logits.shape, expected_shape)

lm_model = BartForMaskedLM(config)
output = lm_model(input_ids)[0]
output = lm_model(input_ids=input_ids)[0]
expected_shape = (self.batch_size, input_ids.shape[1], config.vocab_size)
self.assertEqual(output.shape, expected_shape)

@@ -230,7 +230,7 @@ def test_inference_no_head(self):
model = BartModel.from_pretrained("bart-large")
input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
with torch.no_grad():
output = model(input_ids)[0]
output = model(input_ids=input_ids)[0]
expected_shape = torch.Size((1, 11, 1024))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.Tensor(
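
All of the test edits above follow from the new keyword-only calling convention: positional calls such as model(input_ids) no longer match the **kwargs signatures, since prepare_encoder_decoder_model_kwargs routes arguments by name (optionally with encoder_/decoder_ prefixes). Below is a rough usage sketch under that convention; the tiny configuration values are invented for illustration and are not the ones used in the tests:

import torch
from transformers import BartConfig, BartForSequenceClassification

config = BartConfig(
    vocab_size=99, d_model=24,
    encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
    num_labels=2,
)
model = BartForSequenceClassification(config)
input_ids = torch.tensor([[0, 5, 6, 7, 2]])  # ends with eos (id 2), as the classification head expects

logits = model(input_ids=input_ids)[0]  # keyword call, as in the updated tests
print(logits.shape)                     # torch.Size([1, 2])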