
Commit 226dad1

major refactor (reuse-bart)
1 parent 38cc9c1 commit 226dad1

File tree

src/transformers/modeling_fsmt.py
tests/test_modeling_fsmt.py

2 files changed: +23 -169 lines changed

src/transformers/modeling_fsmt.py

Lines changed: 20 additions & 138 deletions
@@ -46,7 +46,14 @@
     add_start_docstrings_to_callable,
     replace_return_docstrings,
 )
-from .modeling_bart import DecoderLayer, EncoderLayer
+from .modeling_bart import (
+    DecoderLayer,
+    EncoderLayer,
+    LayerNorm,
+    _prepare_bart_decoder_inputs,
+    _reorder_buffer,
+    invert_mask,
+)
 from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, Seq2SeqLMOutput, Seq2SeqModelOutput
 from .modeling_utils import PreTrainedModel
 
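Note: for context, here is a minimal standalone sketch (not part of the diff) of one of the helpers now imported from modeling_bart; the function body matches the FSMT copy deleted further down, and the example tensor is hypothetical:

import torch

def invert_mask(attention_mask):
    """Turns 1->0, 0->1, False->True, True-> False"""
    assert attention_mask.dim() == 2
    return attention_mask.eq(0)

# hypothetical attention mask: 1 marks real tokens, 0 marks padding
mask = torch.tensor([[1, 1, 0], [1, 0, 0]])
print(invert_mask(mask))
# tensor([[False, False,  True],
#         [False,  True,  True]])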
@@ -97,43 +104,36 @@
 
     Here is how to compare BLEU scores against fairseq implementation:
 
-    # Note: to match fairseq params you need to set num_beams=50 in
-    # `configuration_fsmt.py` and lower BS as it'll need more GPU memory
-
-    cd examples/seq2seq
-
     # en-ru
 
     export PAIR=en-ru
     export DATA_DIR=data/$PAIR
     export SAVE_DIR=data/$PAIR
     export BS=8
+    export NUM_BEAMS=50
     mkdir -p $DATA_DIR
     sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
     sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
     echo $PAIR
-    PYTHONPATH="../../src" python run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation
+    PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
 
     # (fairseq BLEU: 36.4 http://matrix.statmt.org/matrix/output/1914?score_id=37605)
 
 
-
-
     # ru-en
 
     export PAIR=ru-en
     export DATA_DIR=data/$PAIR
     export SAVE_DIR=data/$PAIR
     export BS=8
+    export NUM_BEAMS=50
     mkdir -p $DATA_DIR
     sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
     sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
-    echo $PAIR
-    PYTHONPATH="../../src" python run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation
-
-    # (fairseq BLEU: 41.3 http://matrix.statmt.org/matrix/output/1907?run_id=6937)
+    PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
 
 
+    # (fairseq BLEU: 41.3 http://matrix.statmt.org/matrix/output/1907?run_id=6937)
 
 
     # de-en
@@ -142,11 +142,12 @@
     export DATA_DIR=data/$PAIR
     export SAVE_DIR=data/$PAIR
     export BS=8
+    export NUM_BEAMS=50
     mkdir -p $DATA_DIR
     sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
     sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
     echo $PAIR
-    PYTHONPATH="../../src" python run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation
+    PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
 
     # (fairseq BLEU: 42.3 http://matrix.statmt.org/matrix/output/1902?run_id=6750)
 
@@ -162,7 +163,7 @@
     sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
     sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
     echo $PAIR
-    PYTHONPATH="../../src" python run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation
+    PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py stas/fsmt-wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
 
     # (fairseq BLEU: 43.1 http://matrix.statmt.org/matrix/output/1909?run_id=6862)
 
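Note: the eval commands above now pass the beam size explicitly (--num_beams $NUM_BEAMS) instead of the old advice to hardcode num_beams=50 in configuration_fsmt.py. A hedged sketch of the equivalent from the Python side — the model id comes from the commands above, the input sentence is hypothetical, and generate(num_beams=...) is the standard transformers argument, not something this commit adds:

from transformers import FSMTForConditionalGeneration, FSMTTokenizer

tokenizer = FSMTTokenizer.from_pretrained("stas/fsmt-wmt19-en-ru")
model = FSMTForConditionalGeneration.from_pretrained("stas/fsmt-wmt19-en-ru")

batch = tokenizer(["Machine learning is great."], return_tensors="pt")  # hypothetical input
generated = model.generate(batch["input_ids"], num_beams=50)  # beam size matching fairseq
print(tokenizer.decode(generated[0], skip_special_tokens=True))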
@@ -171,8 +172,7 @@
 
 FSMT_START_DOCSTRING = r"""
 
-    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and
-    refer to the PyTorch documentation for all matters related to general usage and behavior.
+    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
 
     Parameters:
         config (:class:`~transformers.FSMTConfig`): Model configuration class with all the parameters of the model.
@@ -238,33 +238,6 @@
 """
 
 
-def invert_mask(attention_mask):
-    """Turns 1->0, 0->1, False->True, True-> False"""
-    assert attention_mask.dim() == 2
-    return attention_mask.eq(0)
-
-
-def _prepare_fsmt_decoder_inputs(
-    config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
-):
-    """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
-    none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
-    Note: this is not called during generation
-    """
-    pad_token_id = config.pad_token_id
-    if decoder_input_ids is None:
-        decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
-    bsz, tgt_len = decoder_input_ids.size()
-    if decoder_padding_mask is None:
-        decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
-    else:
-        decoder_padding_mask = invert_mask(decoder_padding_mask)
-    causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
-        dtype=causal_mask_dtype, device=decoder_input_ids.device
-    )
-    return decoder_input_ids, decoder_padding_mask, causal_mask
-
-
 class PretrainedFSMTModel(PreTrainedModel):
     config_class = FSMTConfig
     base_model_prefix = "model"
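Note: the deleted _prepare_fsmt_decoder_inputs is replaced by bart's _prepare_bart_decoder_inputs at the call sites below. A standalone sketch (not part of the diff) of the causal-mask construction it performs, with fill_with_neg_inf copied from code this commit also deletes and the expected output taken from the test removed in tests/test_modeling_fsmt.py:

import torch

def fill_with_neg_inf(t):
    """FP16-compatible function that fills a input_ids with -inf."""
    return t.float().fill_(float("-inf")).type_as(t)

tgt_len = 3  # hypothetical decoder sequence length
causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1)
print(causal_mask)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])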
@@ -293,36 +266,6 @@ def dummy_inputs(self):
         return dummy_inputs
 
 
-def _make_linear_from_emb(emb):
-    vocab_size, emb_size = emb.weight.shape
-    lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
-    lin_layer.weight.data = emb.weight.data
-    return lin_layer
-
-
-# Helper Functions, mostly for making masks
-def _check_shapes(shape_1, shape2):
-    if shape_1 != shape2:
-        raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
-
-
-def shift_tokens_right(input_ids, pad_token_id):
-    """Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
-    prev_output_tokens = input_ids.clone()
-    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
-    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
-    prev_output_tokens[:, 1:] = input_ids[:, :-1]
-    return prev_output_tokens
-
-
-def make_padding_mask(input_ids, padding_idx=1):
-    """True for pad tokens"""
-    padding_mask = input_ids.eq(padding_idx)
-    if not padding_mask.any():
-        padding_mask = None
-    return padding_mask
-
-
 # Helper Modules
 
 
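Note: shift_tokens_right and make_padding_mask are deleted here because modeling_bart carries identical copies (used internally by _prepare_bart_decoder_inputs). A standalone sketch (not part of the diff) of shift_tokens_right, with inputs and assertions borrowed from the test removed in tests/test_modeling_fsmt.py:

import torch

def shift_tokens_right(input_ids, pad_token_id):
    """Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
    prev_output_tokens = input_ids.clone()
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens

input_ids = torch.tensor([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]])
shifted = shift_tokens_right(input_ids, 1)  # pad_token_id=1
assert shifted.shape == input_ids.shape
assert torch.eq(shifted[:, 0], 2).all()  # every row now starts with <eos> (id 2)
assert shifted.eq(1).sum() == input_ids.eq(1).sum() - 1  # one pad token consumed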
@@ -592,70 +535,11 @@ def forward(
         )
 
 
-def _reorder_buffer(attn_cache, new_order):
-    for k, input_buffer_k in attn_cache.items():
-        if input_buffer_k is not None:
-            attn_cache[k] = input_buffer_k.index_select(0, new_order)
-    return attn_cache
-
-
-# XXX: remove this and its references
-class LearnedPositionalEmbedding(nn.Embedding):
-    """
-    This module learns positional embeddings up to a fixed maximum size.
-    Padding ids are ignored by either offsetting based on padding_idx
-    or by setting padding_idx to None and ensuring that the appropriate
-    position ids are passed to the forward function.
-    """
-
-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset):
-        # FSMT is set up so that if padding_idx is specified then offset the embedding ids by 2
-        # and adjust num_embeddings appropriately. Other models dont have this hack
-        self.offset = offset
-        assert padding_idx is not None
-        num_embeddings += offset
-        super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
-
-    def forward(self, input_ids, use_cache=False):
-        """Input is expected to be of size [bsz x seqlen]."""
-        bsz, seq_len = input_ids.shape[:2]
-        if use_cache:
-            positions = input_ids.data.new(1, 1).fill_(seq_len - 1)  # called before slicing
-        else:
-            # starts at 0, ends at 1-seq_len
-            positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
-        return super().forward(positions + self.offset)
-
-
-def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
-    if torch.cuda.is_available():
-        try:
-            from apex.normalization import FusedLayerNorm
-
-            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
-        except ImportError:
-            pass
-    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
-
-
-def fill_with_neg_inf(t):
-    """FP16-compatible function that fills a input_ids with -inf."""
-    return t.float().fill_(float("-inf")).type_as(t)
-
-
 # Public API
 def _get_shape(t):
     return getattr(t, "shape", None)
 
 
-# def output_projection(self):
-#     return nn.Linear(
-#         self.embed_tokens.weight.shape[1],
-#         self.embed_tokens.weight.shape[0],
-#         bias=False,
-#     )
-
-
 @add_start_docstrings(
     "The bare FSMT Model outputting raw hidden-states without any specific head on top.",
     FSMT_START_DOCSTRING,
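Note: _reorder_buffer also moves to the shared import from modeling_bart. A standalone sketch (not part of the diff) of what it does during beam search — reordering every cached attention tensor along the batch/beam dimension; the cache contents here are hypothetical:

import torch

def _reorder_buffer(attn_cache, new_order):
    for k, input_buffer_k in attn_cache.items():
        if input_buffer_k is not None:
            attn_cache[k] = input_buffer_k.index_select(0, new_order)
    return attn_cache

cache = {
    "prev_key": torch.arange(6.0).reshape(3, 2),  # 3 beams x 2 features
    "prev_key_padding_mask": None,  # None entries are left untouched
}
_reorder_buffer(cache, torch.tensor([2, 0, 1]))  # new beam order chosen by the search
print(cache["prev_key"])
# tensor([[4., 5.],
#         [0., 1.],
#         [2., 3.]])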
@@ -713,7 +597,7 @@ def forward(
 
         # make masks if user doesn't supply
         if not use_cache:
-            decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_fsmt_decoder_inputs(
+            decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
                 self.config,
                 input_ids,
                 decoder_input_ids=decoder_input_ids,
@@ -772,13 +656,13 @@ def get_input_embeddings(self):
         return self.encoder.embed_tokens
 
     def set_input_embeddings(self, value):
-        self.encoder.embed_tokens = value  # self.encoder_embed_tokens = value
+        self.encoder.embed_tokens = value
 
     def get_output_embeddings(self):
         return self.decoder.embed_tokens
 
     def set_output_embeddings(self, value):
-        self.decoder.embed_tokens = value  # self.decoder_embed_tokens = value
+        self.decoder.embed_tokens = value
 
 
 @add_start_docstrings(
@@ -935,8 +819,6 @@ def get_encoder(self):
 
     def get_output_embeddings(self):
         return self.model.decoder.embed_tokens
-        # XXX: it was, but probably is not needed here
-        # return _make_linear_from_emb(self.decoder.embed_tokens)  # make it on the fly
 
 
 def make_positions(tensor, padding_idx: int):

tests/test_modeling_fsmt.py

Lines changed: 3 additions & 31 deletions
@@ -30,13 +30,8 @@
 import torch
 
 from transformers import FSMTConfig, FSMTForConditionalGeneration, FSMTModel, FSMTTokenizer
-from transformers.modeling_fsmt import (
-    SinusoidalPositionalEmbedding,
-    _prepare_fsmt_decoder_inputs,
-    invert_mask,
-    shift_tokens_right,
-)
-PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
+from transformers.modeling_bart import _prepare_bart_decoder_inputs, invert_mask
+from transformers.modeling_fsmt import SinusoidalPositionalEmbedding
 
 
 @require_torch
@@ -164,7 +159,7 @@ def test_advanced_inputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.use_cache = False
         inputs_dict["input_ids"][:, -2:] = config.pad_token_id
-        decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_fsmt_decoder_inputs(
+        decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
             config, inputs_dict["input_ids"]
         )
         model = FSMTModel(config).to(torch_device).eval()
@@ -287,15 +282,6 @@ def test_generate_beam_search(self):
         self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length))
         # TODO(SS): uneven length batches, empty inputs
 
-    def test_shift_tokens_right(self):
-        input_ids = torch.Tensor([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]]).long()
-        shifted = shift_tokens_right(input_ids, 1)
-        n_pad_before = input_ids.eq(1).float().sum()
-        n_pad_after = shifted.eq(1).float().sum()
-        self.assertEqual(shifted.shape, input_ids.shape)
-        self.assertEqual(n_pad_after, n_pad_before - 1)
-        self.assertTrue(torch.eq(shifted[:, 0], 2).all())
-
     def test_generate_fp16(self):
         config, input_ids, batch_size = self._get_config_and_data()
         attention_mask = input_ids.ne(1).to(torch_device)
@@ -310,20 +296,6 @@ def test_dummy_inputs(self):
         model = FSMTForConditionalGeneration(config).eval().to(torch_device)
         model(**model.dummy_inputs)
 
-    def test_prepare_fsmt_decoder_inputs(self):
-        config, *_ = self._get_config_and_data()
-        input_ids = _long_tensor(([4, 4, 2]))
-        decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
-        ignore = float("-inf")
-        decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_fsmt_decoder_inputs(
-            config, input_ids, decoder_input_ids
-        )
-        expected_causal_mask = torch.tensor(
-            [[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]]  # never attend to the final token, because its pad
-        ).to(input_ids.device)
-        self.assertEqual(decoder_attn_mask.size(), decoder_input_ids.size())
-        self.assertTrue(torch.eq(expected_causal_mask, causal_mask).all())
-
     def test_resize_tokens_embeddings_more(self):
         config, input_ids, _ = self._get_config_and_data()
 
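Note: after this refactor, downstream code follows the same import split as the updated test file — shared encoder/decoder helpers come from the BART module, and only FSMT-specific symbols stay in modeling_fsmt. A minimal sketch (paths valid for the source tree at this commit; later transformers releases reorganized these modules):

from transformers.modeling_bart import _prepare_bart_decoder_inputs, invert_mask
from transformers.modeling_fsmt import SinusoidalPositionalEmbedding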