 46 |  46 |     add_start_docstrings_to_callable,
 47 |  47 |     replace_return_docstrings,
 48 |  48 | )
 49 |     | -from .modeling_bart import DecoderLayer, EncoderLayer
    |  49 | +from .modeling_bart import (
    |  50 | +    DecoderLayer,
    |  51 | +    EncoderLayer,
    |  52 | +    LayerNorm,
    |  53 | +    _prepare_bart_decoder_inputs,
    |  54 | +    _reorder_buffer,
    |  55 | +    invert_mask,
    |  56 | +)
 50 |  57 | from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, Seq2SeqLMOutput, Seq2SeqModelOutput
 51 |  58 | from .modeling_utils import PreTrainedModel
 52 |  59 |
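Note on the helpers pulled in above: `invert_mask` (defined in `modeling_bart`, and deleted from this file further down) simply flips a 0/1 attention mask so that padding positions become True. A minimal standalone sketch of that behavior, re-implemented here purely for illustration:

    import torch

    # Mirrors modeling_bart.invert_mask: 1 -> False (token kept), 0 -> True (padding).
    def invert_mask(attention_mask):
        assert attention_mask.dim() == 2
        return attention_mask.eq(0)

    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    print(invert_mask(mask))
    # tensor([[False, False, False,  True],
    #         [False, False,  True,  True]])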
 97 | 104 |
 98 | 105 | Here is how to compare BLEU scores against fairseq implementation:
 99 | 106 |
100 |     | -# Note: to match fairseq params you need to set num_beams=50 in
101 |     | -# `configuration_fsmt.py` and lower BS as it'll need more GPU memory
102 |     | -
103 |     | -cd examples/seq2seq
104 |     | -
105 | 107 | # en-ru
106 | 108 |
107 | 109 | export PAIR=en-ru
108 | 110 | export DATA_DIR=data/$PAIR
109 | 111 | export SAVE_DIR=data/$PAIR
110 | 112 | export BS=8
    | 113 | +export NUM_BEAMS=50
111 | 114 | mkdir -p $DATA_DIR
112 | 115 | sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
113 | 116 | sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
114 | 117 | echo $PAIR
115 |     | -PYTHONPATH="../../src" python run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation
    | 118 | +PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
116 | 119 |
117 | 120 | # (fairseq BLEU: 36.4 http://matrix.statmt.org/matrix/output/1914?score_id=37605)
118 | 121 |
119 | 122 |
120 |     | -
121 |     | -
122 | 123 | # ru-en
123 | 124 |
124 | 125 | export PAIR=ru-en
125 | 126 | export DATA_DIR=data/$PAIR
126 | 127 | export SAVE_DIR=data/$PAIR
127 | 128 | export BS=8
    | 129 | +export NUM_BEAMS=50
128 | 130 | mkdir -p $DATA_DIR
129 | 131 | sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
130 | 132 | sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
131 |     | -echo $PAIR
132 |     | -PYTHONPATH="../../src" python run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation
133 |     | -
134 |     | -# (fairseq BLEU: 41.3 http://matrix.statmt.org/matrix/output/1907?run_id=6937)
    | 133 | +PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
135 | 134 |
136 | 135 |
    | 136 | +# (fairseq BLEU: 41.3 http://matrix.statmt.org/matrix/output/1907?run_id=6937)
137 | 137 |
138 | 138 |
139 | 139 | # de-en
140 | 140 |
141 | 141 | export PAIR=de-en
142 | 142 | export DATA_DIR=data/$PAIR
143 | 143 | export SAVE_DIR=data/$PAIR
144 | 144 | export BS=8
    | 145 | +export NUM_BEAMS=50
145 | 146 | mkdir -p $DATA_DIR
146 | 147 | sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
147 | 148 | sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
148 | 149 | echo $PAIR
149 |     | -PYTHONPATH="../../src" python run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation
    | 150 | +PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
150 | 151 |
151 | 152 | # (fairseq BLEU: 42.3 http://matrix.statmt.org/matrix/output/1902?run_id=6750)
152 | 153 |
162 | 163 | sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
163 | 164 | sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
164 | 165 | echo $PAIR
165 |     | -PYTHONPATH="../../src" python run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation
    | 166 | +PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
166 | 167 |
167 | 168 | # (fairseq BLEU: 43.1 http://matrix.statmt.org/matrix/output/1909?run_id=6862)
168 | 169 |
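The checkpoints evaluated above can also be exercised directly for a quick sanity check outside the BLEU harness. A minimal sketch, assuming the `stas/fsmt-wmt19-en-ru` checkpoint id used in the commands above, the `FSMTForConditionalGeneration` class defined in this file, and its companion `FSMTTokenizer`:

    import torch
    from transformers import FSMTForConditionalGeneration, FSMTTokenizer

    checkpoint = "stas/fsmt-wmt19-en-ru"  # checkpoint id taken from the eval commands above
    tokenizer = FSMTTokenizer.from_pretrained(checkpoint)
    model = FSMTForConditionalGeneration.from_pretrained(checkpoint)

    input_ids = tokenizer.encode("Machine learning is great, isn't it?", return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(input_ids, num_beams=5)  # num_beams=50 matches the fairseq setup
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))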
171 | 172 |
172 | 173 | FSMT_START_DOCSTRING = r"""
173 | 174 |
174 |     | -    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and
175 |     | -    refer to the PyTorch documentation for all matters related to general usage and behavior.
    | 175 | +    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
176 | 176 |
177 | 177 |     Parameters:
178 | 178 |         config (:class:`~transformers.FSMTConfig`): Model configuration class with all the parameters of the model.
238 | 238 | """
239 | 239 |
240 | 240 |
241 |     | -def invert_mask(attention_mask):
242 |     | -    """Turns 1->0, 0->1, False->True, True-> False"""
243 |     | -    assert attention_mask.dim() == 2
244 |     | -    return attention_mask.eq(0)
245 |     | -
246 |     | -
247 |     | -def _prepare_fsmt_decoder_inputs(
248 |     | -    config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
249 |     | -):
250 |     | -    """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
251 |     | -    none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
252 |     | -    Note: this is not called during generation
253 |     | -    """
254 |     | -    pad_token_id = config.pad_token_id
255 |     | -    if decoder_input_ids is None:
256 |     | -        decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
257 |     | -    bsz, tgt_len = decoder_input_ids.size()
258 |     | -    if decoder_padding_mask is None:
259 |     | -        decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
260 |     | -    else:
261 |     | -        decoder_padding_mask = invert_mask(decoder_padding_mask)
262 |     | -    causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
263 |     | -        dtype=causal_mask_dtype, device=decoder_input_ids.device
264 |     | -    )
265 |     | -    return decoder_input_ids, decoder_padding_mask, causal_mask
266 |     | -
267 |     | -
268 | 241 | class PretrainedFSMTModel(PreTrainedModel):
269 | 242 |     config_class = FSMTConfig
270 | 243 |     base_model_prefix = "model"
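The FSMT-specific copy removed above is superseded by `_prepare_bart_decoder_inputs`, which follows the same contract: shift the target ids right when no `decoder_input_ids` are given, build a decoder padding mask, and build a causal mask. As a minimal sketch, the causal mask is built the same way as in the removed code (shown here for a toy target length of 4):

    import torch

    # Fill with -inf, then keep only the strictly upper-triangular part, so that
    # position i can attend to positions <= i and nothing after it.
    tgt_len = 4
    causal_mask = torch.triu(torch.zeros(tgt_len, tgt_len).fill_(float("-inf")), 1)
    print(causal_mask)
    # tensor([[0., -inf, -inf, -inf],
    #         [0., 0., -inf, -inf],
    #         [0., 0., 0., -inf],
    #         [0., 0., 0., 0.]])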
@@ -293,36 +266,6 @@ def dummy_inputs(self):
293 | 266 |         return dummy_inputs
294 | 267 |
295 | 268 |
296 |     | -def _make_linear_from_emb(emb):
297 |     | -    vocab_size, emb_size = emb.weight.shape
298 |     | -    lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
299 |     | -    lin_layer.weight.data = emb.weight.data
300 |     | -    return lin_layer
301 |     | -
302 |     | -
303 |     | -# Helper Functions, mostly for making masks
304 |     | -def _check_shapes(shape_1, shape2):
305 |     | -    if shape_1 != shape2:
306 |     | -        raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
307 |     | -
308 |     | -
309 |     | -def shift_tokens_right(input_ids, pad_token_id):
310 |     | -    """Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
311 |     | -    prev_output_tokens = input_ids.clone()
312 |     | -    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
313 |     | -    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
314 |     | -    prev_output_tokens[:, 1:] = input_ids[:, :-1]
315 |     | -    return prev_output_tokens
316 |     | -
317 |     | -
318 |     | -def make_padding_mask(input_ids, padding_idx=1):
319 |     | -    """True for pad tokens"""
320 |     | -    padding_mask = input_ids.eq(padding_idx)
321 |     | -    if not padding_mask.any():
322 |     | -        padding_mask = None
323 |     | -    return padding_mask
324 |     | -
325 |     | -
326 | 269 | # Helper Modules
327 | 270 |
328 | 271 |
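For reference, the removed `shift_tokens_right` (the same helper exists in `modeling_bart`) rotates the last non-pad token, usually `</s>`, into position 0 and shifts everything else right by one; this is how decoder inputs are derived from the targets. A small worked example with illustrative token ids:

    import torch

    def shift_tokens_right(input_ids, pad_token_id):
        # Wrap the last non-pad token (usually </s>) to position 0, shift the rest right by one.
        prev_output_tokens = input_ids.clone()
        index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
        prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
        prev_output_tokens[:, 1:] = input_ids[:, :-1]
        return prev_output_tokens

    ids = torch.tensor([[5, 6, 7, 2, 1]])  # 2 = </s>, 1 = <pad> (illustrative ids)
    print(shift_tokens_right(ids, pad_token_id=1))
    # tensor([[2, 5, 6, 7, 2]])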
@@ -592,70 +535,11 @@ def forward(
592 | 535 |         )
593 | 536 |
594 | 537 |
595 |     | -def _reorder_buffer(attn_cache, new_order):
596 |     | -    for k, input_buffer_k in attn_cache.items():
597 |     | -        if input_buffer_k is not None:
598 |     | -            attn_cache[k] = input_buffer_k.index_select(0, new_order)
599 |     | -    return attn_cache
600 |     | -
601 |     | -
602 |     | -# XXX: remove this and its references
603 |     | -class LearnedPositionalEmbedding(nn.Embedding):
604 |     | -    """
605 |     | -    This module learns positional embeddings up to a fixed maximum size.
606 |     | -    Padding ids are ignored by either offsetting based on padding_idx
607 |     | -    or by setting padding_idx to None and ensuring that the appropriate
608 |     | -    position ids are passed to the forward function.
609 |     | -    """
610 |     | -
611 |     | -    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset):
612 |     | -        # FSMT is set up so that if padding_idx is specified then offset the embedding ids by 2
613 |     | -        # and adjust num_embeddings appropriately. Other models dont have this hack
614 |     | -        self.offset = offset
615 |     | -        assert padding_idx is not None
616 |     | -        num_embeddings += offset
617 |     | -        super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
618 |     | -
619 |     | -    def forward(self, input_ids, use_cache=False):
620 |     | -        """Input is expected to be of size [bsz x seqlen]."""
621 |     | -        bsz, seq_len = input_ids.shape[:2]
622 |     | -        if use_cache:
623 |     | -            positions = input_ids.data.new(1, 1).fill_(seq_len - 1)  # called before slicing
624 |     | -        else:
625 |     | -            # starts at 0, ends at 1-seq_len
626 |     | -            positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
627 |     | -        return super().forward(positions + self.offset)
628 |     | -
629 |     | -
630 |     | -def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
631 |     | -    if torch.cuda.is_available():
632 |     | -        try:
633 |     | -            from apex.normalization import FusedLayerNorm
634 |     | -
635 |     | -            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
636 |     | -        except ImportError:
637 |     | -            pass
638 |     | -    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
639 |     | -
640 |     | -
641 |     | -def fill_with_neg_inf(t):
642 |     | -    """FP16-compatible function that fills a input_ids with -inf."""
643 |     | -    return t.float().fill_(float("-inf")).type_as(t)
644 |     | -
645 |     | -
646 | 538 | # Public API
647 | 539 | def _get_shape(t):
648 | 540 |     return getattr(t, "shape", None)
649 | 541 |
650 | 542 |
651 |     | -# def output_projection(self):
652 |     | -#     return nn.Linear(
653 |     | -#         self.embed_tokens.weight.shape[1],
654 |     | -#         self.embed_tokens.weight.shape[0],
655 |     | -#         bias=False,
656 |     | -#     )
657 |     | -
658 |     | -
659 | 543 | @add_start_docstrings(
660 | 544 |     "The bare FSMT Model outputting raw hidden-states without any specific head on top.",
661 | 545 |     FSMT_START_DOCSTRING,
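`_reorder_buffer`, now imported from `modeling_bart`, is the beam-search hook that reorders cached key/value projections along the batch/beam dimension whenever hypotheses are reshuffled. A minimal sketch with toy tensors:

    import torch

    def _reorder_buffer(attn_cache, new_order):
        # Reorder every cached tensor along dim 0 (the batch/beam dimension).
        for k, input_buffer_k in attn_cache.items():
            if input_buffer_k is not None:
                attn_cache[k] = input_buffer_k.index_select(0, new_order)
        return attn_cache

    cache = {"prev_key": torch.arange(6.0).view(3, 2), "prev_value": None}
    reordered = _reorder_buffer(cache, torch.tensor([2, 0, 1]))
    print(reordered["prev_key"])
    # tensor([[4., 5.],
    #         [0., 1.],
    #         [2., 3.]])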
@@ -713,7 +597,7 @@ def forward(
713 | 597 |
714 | 598 |         # make masks if user doesn't supply
715 | 599 |         if not use_cache:
716 |     | -            decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_fsmt_decoder_inputs(
    | 600 | +            decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
717 | 601 |                 self.config,
718 | 602 |                 input_ids,
719 | 603 |                 decoder_input_ids=decoder_input_ids,
@@ -772,13 +656,13 @@ def get_input_embeddings(self):
772 | 656 |         return self.encoder.embed_tokens
773 | 657 |
774 | 658 |     def set_input_embeddings(self, value):
775 |     | -        self.encoder.embed_tokens = value  # self.encoder_embed_tokens = value
    | 659 | +        self.encoder.embed_tokens = value
776 | 660 |
777 | 661 |     def get_output_embeddings(self):
778 | 662 |         return self.decoder.embed_tokens
779 | 663 |
780 | 664 |     def set_output_embeddings(self, value):
781 |     | -        self.decoder.embed_tokens = value  # self.decoder_embed_tokens = value
    | 665 | +        self.decoder.embed_tokens = value
782 | 666 |
783 | 667 |
784 | 668 | @add_start_docstrings(
@@ -935,8 +819,6 @@ def get_encoder(self):
935 | 819 |
936 | 820 |     def get_output_embeddings(self):
937 | 821 |         return self.model.decoder.embed_tokens
938 |     | -        # XXX: it was, but probably is not needed here
939 |     | -        # return _make_linear_from_emb(self.decoder.embed_tokens)  # make it on the fly
940 | 822 |
941 | 823 |
942 | 824 | def make_positions(tensor, padding_idx: int):