
Merge bart generate into default generate #3140

Merged
Changes from all commits (29 commits)
d8e2b3c
fix conflicts
patrickvonplaten Mar 6, 2020
c0d9dd3
refactored code a bit and made more generic
patrickvonplaten Mar 5, 2020
ff64822
fix conflicts
patrickvonplaten Mar 6, 2020
7cba11f
better naming
patrickvonplaten Mar 5, 2020
aceb3fb
only do output_past=True for language generation in bart
patrickvonplaten Mar 5, 2020
5b3000d
renamed min_len to min_length
patrickvonplaten Mar 5, 2020
7a11e92
work in progress
patrickvonplaten Mar 6, 2020
4212169
comment out stuff
patrickvonplaten Mar 6, 2020
333affc
add current changes
patrickvonplaten Mar 6, 2020
77e6775
add current changes
patrickvonplaten Mar 6, 2020
c62444d
fix conflicts
patrickvonplaten Mar 8, 2020
2acfe63
best current version and make style
patrickvonplaten Mar 6, 2020
d880a5f
finalized PR
patrickvonplaten Mar 7, 2020
629aac9
do not allow do_sample and weird force bos token things
patrickvonplaten Mar 7, 2020
a5751f7
fix bug with attention_mask as optional input argument
patrickvonplaten Mar 8, 2020
41b437e
add draft version of proposed changes for ROUGE score
patrickvonplaten Mar 8, 2020
ca2047b
refactor variable naming and improve tf generate in line with torch g…
patrickvonplaten Mar 9, 2020
a2c8e51
fix torch to tf translation
patrickvonplaten Mar 9, 2020
374deef
fixed typo
patrickvonplaten Mar 9, 2020
cf06290
remove ipdb
patrickvonplaten Mar 9, 2020
1098971
rename variable
patrickvonplaten Mar 9, 2020
ca1330f
do not mess with the negative sign
patrickvonplaten Mar 10, 2020
9b8ee8c
delete print and make style
patrickvonplaten Mar 10, 2020
7351a8d
re-add scoring filtering
patrickvonplaten Mar 10, 2020
d997ac7
fix typo
patrickvonplaten Mar 10, 2020
1ba21f9
fix bug in tf no_repeat_ngram_size
patrickvonplaten Mar 10, 2020
a332cc9
finalize generation merge
patrickvonplaten Mar 11, 2020
bc9d5d9
make all tensors half precision
patrickvonplaten Mar 11, 2020
ac303ea
fix problem with half
patrickvonplaten Mar 11, 2020
2 changes: 1 addition & 1 deletion examples/summarization/bart/evaluate_cnn.py
```diff
@@ -28,7 +28,7 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
         num_beams=4,
         length_penalty=2.0,
         max_length=140,
-        min_len=55,
+        min_length=55,
         no_repeat_ngram_size=3,
     )
     dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
```
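
The only functional change here is the rename `min_len` → `min_length`, aligning BART's summarization example with the argument name used by the shared `generate()`. A minimal sketch of the updated call site (the checkpoint name and example text are placeholders for illustration, not taken from this PR):

```python
from transformers import BartForConditionalGeneration, BartTokenizer

# Placeholder checkpoint name for illustration.
tokenizer = BartTokenizer.from_pretrained("bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("bart-large-cnn")

inputs = tokenizer.batch_encode_plus(["Some long article text ..."], return_tensors="pt")
summary_ids = model.generate(
    inputs["input_ids"],
    num_beams=4,
    length_penalty=2.0,
    max_length=140,
    min_length=55,  # renamed from `min_len` in this PR
    no_repeat_ngram_size=3,
)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```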
8 changes: 5 additions & 3 deletions src/transformers/configuration_bart.py
```diff
@@ -40,8 +40,9 @@ def __init__(
         self,
         activation_dropout=0.0,
         vocab_size=50265,
+        bos_token_id=0,
         pad_token_id=1,
-        eos_token_id=2,
+        eos_token_ids=[2],
         d_model=1024,
         encoder_ffn_dim=4096,
         encoder_layers=12,
@@ -58,7 +59,7 @@ def __init__(
         classifier_dropout=0.0,
         output_past=False,
         num_labels=3,
-        bos_token_id=0,
+        is_encoder_decoder=True,
         **common_kwargs
     ):
         r"""
@@ -72,11 +73,12 @@
             output_past=output_past,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
+            eos_token_ids=eos_token_ids,
+            is_encoder_decoder=is_encoder_decoder,
             **common_kwargs,
         )
         self.vocab_size = vocab_size
         self.d_model = d_model  # encoder_embed_dim and decoder_embed_dim
-        self.eos_token_id = eos_token_id
         self.encoder_ffn_dim = encoder_ffn_dim
         self.encoder_layers = self.num_hidden_layers = encoder_layers
         self.encoder_attention_heads = encoder_attention_heads
```
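
The net effect on `BartConfig`: the scalar `eos_token_id=2` becomes the list-valued `eos_token_ids=[2]` expected by the shared generation code, `bos_token_id` moves next to the other special-token defaults, and both `eos_token_ids` and the new `is_encoder_decoder=True` flag are forwarded to the base class instead of being set locally. A quick sketch of the resulting defaults:

```python
from transformers import BartConfig

config = BartConfig()
assert config.bos_token_id == 0
assert config.pad_token_id == 1
assert config.eos_token_ids == [2]       # was a scalar `eos_token_id` before this PR
assert config.is_encoder_decoder is True
```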
5 changes: 4 additions & 1 deletion src/transformers/configuration_t5.py
```diff
@@ -75,9 +75,12 @@ def __init__(
         dropout_rate=0.1,
         layer_norm_epsilon=1e-6,
         initializer_factor=1.0,
+        is_encoder_decoder=True,
         **kwargs
     ):
-        super().__init__(**kwargs)
+        super().__init__(
+            is_encoder_decoder=is_encoder_decoder, **kwargs,
+        )
         self.vocab_size = vocab_size
         self.n_positions = n_positions
         self.d_model = d_model
```
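
Same idea for `T5Config`: the model now advertises itself as an encoder-decoder by default, while callers can still override the flag through the constructor. For example:

```python
from transformers import T5Config

assert T5Config().is_encoder_decoder is True                            # new default
assert T5Config(is_encoder_decoder=False).is_encoder_decoder is False   # still overridable
```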
3 changes: 3 additions & 0 deletions src/transformers/configuration_utils.py
```diff
@@ -65,10 +65,12 @@ def __init__(self, **kwargs):
         self.pruned_heads = kwargs.pop("pruned_heads", {})

         # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
+        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
         self.is_decoder = kwargs.pop("is_decoder", False)

         # Parameters for sequence generation
         self.max_length = kwargs.pop("max_length", 20)
+        self.min_length = kwargs.pop("min_length", 0)
         self.do_sample = kwargs.pop("do_sample", False)
         self.early_stopping = kwargs.pop("early_stopping", False)
         self.num_beams = kwargs.pop("num_beams", 1)
@@ -80,6 +82,7 @@ def __init__(self, **kwargs):
         self.pad_token_id = kwargs.pop("pad_token_id", None)
         self.eos_token_ids = kwargs.pop("eos_token_ids", None)
         self.length_penalty = kwargs.pop("length_penalty", 1.0)
+        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)

         # Fine-tuning task arguments
```
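
With `is_encoder_decoder`, `min_length`, and `no_repeat_ngram_size` living on the base `PretrainedConfig`, the shared `generate()` can fall back on config values whenever the caller omits an argument. A simplified sketch of that fallback pattern (`resolve_generation_args` is a hypothetical helper for illustration, not the library's actual code):

```python
def resolve_generation_args(config, max_length=None, min_length=None, no_repeat_ngram_size=None):
    # Each argument falls back to the config default when not passed explicitly.
    max_length = max_length if max_length is not None else config.max_length
    min_length = min_length if min_length is not None else config.min_length
    no_repeat_ngram_size = (
        no_repeat_ngram_size if no_repeat_ngram_size is not None else config.no_repeat_ngram_size
    )
    return max_length, min_length, no_repeat_ngram_size
```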
289 changes: 16 additions & 273 deletions src/transformers/modeling_bart.py

Large diffs are not rendered by default.
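
This large deletion is the point of the PR: BART's private generation and beam-search code is removed from `modeling_bart.py` so that BART runs through the shared `generate()` like every other model. As a rough, hypothetical illustration of the encoder-decoder flow the shared method must now cover (encode the source once, then decode token by token), assuming the model exposes a `get_encoder()` accessor and accepts `decoder_input_ids`/`encoder_outputs` keywords; this is not the library's actual implementation:

```python
import torch

def greedy_seq2seq_decode(model, input_ids, bos_token_id, eos_token_id, max_length):
    """Hypothetical greedy decoding loop for an encoder-decoder model."""
    # Encode the source sequence once and reuse it at every decoding step.
    encoder_outputs = model.get_encoder()(input_ids)
    decoder_input_ids = torch.full(
        (input_ids.size(0), 1), bos_token_id, dtype=torch.long, device=input_ids.device
    )
    for _ in range(max_length - 1):
        logits = model(
            input_ids, decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs
        )[0]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick
        decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)
        if (next_token == eos_token_id).all():
            break
    return decoder_input_ids
```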
