
Commit

Merge pull request #3140 from patrickvonplaten/merge_bart_generate_into_default_generate

Merge bart generate into default generate
thomwolf authored Mar 11, 2020
2 parents d6de642 + ac303ea commit db29ffc
Showing 9 changed files with 476 additions and 410 deletions.
2 changes: 1 addition & 1 deletion examples/summarization/bart/evaluate_cnn.py
@@ -28,7 +28,7 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
             num_beams=4,
             length_penalty=2.0,
             max_length=140,
-            min_len=55,
+            min_length=55,
             no_repeat_ngram_size=3,
         )
         dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
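
As a quick check of the renamed argument, here is a minimal sketch of the updated summarization call. The class names, the checkpoint name, and the example article are assumptions about the API at this commit, not part of the diff; only the generate() keyword arguments come from the hunk above.

```python
# Minimal sketch of the updated call (assumed class and checkpoint names).
from transformers import BartForConditionalGeneration, BartTokenizer

model = BartForConditionalGeneration.from_pretrained("bart-large-cnn")  # assumed checkpoint name
tokenizer = BartTokenizer.from_pretrained("bart-large-cnn")

article = "Some long news article to summarize ..."  # hypothetical input document
batch = tokenizer.batch_encode_plus([article], max_length=1024, return_tensors="pt")

summaries = model.generate(
    batch["input_ids"],
    num_beams=4,
    length_penalty=2.0,
    max_length=140,
    min_length=55,  # renamed from min_len to match the shared generate() signature
    no_repeat_ngram_size=3,
)
print(tokenizer.decode(summaries[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))
```
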
8 changes: 5 additions & 3 deletions src/transformers/configuration_bart.py
@@ -40,8 +40,9 @@ def __init__(
         self,
         activation_dropout=0.0,
         vocab_size=50265,
+        bos_token_id=0,
         pad_token_id=1,
-        eos_token_id=2,
+        eos_token_ids=[2],
         d_model=1024,
         encoder_ffn_dim=4096,
         encoder_layers=12,
@@ -58,7 +59,7 @@ def __init__(
         classifier_dropout=0.0,
         output_past=False,
         num_labels=3,
-        bos_token_id=0,
+        is_encoder_decoder=True,
         **common_kwargs
     ):
         r"""
@@ -72,11 +73,12 @@
             output_past=output_past,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
+            eos_token_ids=eos_token_ids,
+            is_encoder_decoder=is_encoder_decoder,
             **common_kwargs,
         )
         self.vocab_size = vocab_size
         self.d_model = d_model  # encoder_embed_dim and decoder_embed_dim
-        self.eos_token_id = eos_token_id
         self.encoder_ffn_dim = encoder_ffn_dim
         self.encoder_layers = self.num_hidden_layers = encoder_layers
         self.encoder_attention_heads = encoder_attention_heads
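
A small sketch of what the new BartConfig defaults look like from the outside; the printed values follow the diff above, everything else is illustrative.

```python
# Inspect the updated BartConfig defaults introduced by this change.
from transformers import BartConfig

config = BartConfig()
print(config.bos_token_id)        # 0, now set in the signature and passed to the base class
print(config.pad_token_id)        # 1
print(config.eos_token_ids)       # [2], a list replacing the old scalar eos_token_id
print(config.is_encoder_decoder)  # True, so the shared generate() treats BART as an encoder-decoder
```
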
5 changes: 4 additions & 1 deletion src/transformers/configuration_t5.py
@@ -75,9 +75,12 @@ def __init__(
         dropout_rate=0.1,
         layer_norm_epsilon=1e-6,
         initializer_factor=1.0,
+        is_encoder_decoder=True,
         **kwargs
     ):
-        super().__init__(**kwargs)
+        super().__init__(
+            is_encoder_decoder=is_encoder_decoder, **kwargs,
+        )
         self.vocab_size = vocab_size
         self.n_positions = n_positions
         self.d_model = d_model
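
The same flag now reaches T5 through its config; a minimal check, assuming nothing beyond the diff above.

```python
# T5Config now forwards is_encoder_decoder=True to the base PretrainedConfig.
from transformers import T5Config

config = T5Config()
print(config.is_encoder_decoder)  # True by default after this change
```
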
3 changes: 3 additions & 0 deletions src/transformers/configuration_utils.py
@@ -65,10 +65,12 @@ def __init__(self, **kwargs):
         self.pruned_heads = kwargs.pop("pruned_heads", {})

         # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
+        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
         self.is_decoder = kwargs.pop("is_decoder", False)

         # Parameters for sequence generation
         self.max_length = kwargs.pop("max_length", 20)
+        self.min_length = kwargs.pop("min_length", 0)
         self.do_sample = kwargs.pop("do_sample", False)
         self.early_stopping = kwargs.pop("early_stopping", False)
         self.num_beams = kwargs.pop("num_beams", 1)
@@ -80,6 +82,7 @@ def __init__(self, **kwargs):
         self.pad_token_id = kwargs.pop("pad_token_id", None)
         self.eos_token_ids = kwargs.pop("eos_token_ids", None)
         self.length_penalty = kwargs.pop("length_penalty", 1.0)
+        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)

         # Fine-tuning task arguments
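
The base config now carries the new generation defaults, so every model picks them up unless its own config overrides them; a quick sketch showing the default values set above.

```python
# The new generation attributes live on the base PretrainedConfig.
from transformers import PretrainedConfig

config = PretrainedConfig()
print(config.min_length)            # 0
print(config.no_repeat_ngram_size)  # 0
print(config.is_encoder_decoder)    # False unless a model config (e.g. BartConfig) overrides it
```
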
289 changes: 16 additions & 273 deletions src/transformers/modeling_bart.py

Large diffs are not rendered by default.

