Commit bf0ffe3

[Tests] Diverse Whisper fixes (#33665)
* fix beam indices in token_timestamps
* fix attention_mask in FA2
* correct translation example with the right example
* correct how some tests are using outputs + correct num_frames
* fix shortform batch prev cond tests
* make fix-copies
* make fix-copies
* take care of shifting beam indices
* [run-slow] whisper
* [run-slow] whisper
1 parent ab97a78 commit bf0ffe3

4 files changed (+33, −18 lines)


src/transformers/models/qwen2_audio/modeling_qwen2_audio.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -291,7 +291,7 @@ def forward(
 
         causal_mask = attention_mask
         if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            causal_mask = attention_mask[:, : key_states.shape[-2]]
 
         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
```
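For readers skimming this diff: a minimal, self-contained sketch of what the one-line change addresses, assuming (per the commit message, "fix attention_mask in FA2") that this `forward` belongs to the Flash Attention 2 path. Shapes and variable names below are illustrative, not the module's actual locals.

```python
import torch

batch, q_len, kv_len = 2, 4, 6
key_len = 5  # stand-in for key_states.shape[-2] in the hunk above

# Eager attention works with a 4D causal mask of shape (batch, 1, q_len, kv_len),
# which is what the old 4D slice expects.
causal_mask_4d = torch.zeros(batch, 1, q_len, kv_len)
print(causal_mask_4d[:, :, :, :key_len].shape)  # torch.Size([2, 1, 4, 5])

# Flash Attention 2 is fed a 2D padding mask of shape (batch, kv_len), so only the
# last dimension can be sliced; applying the old 4D slice to this tensor raises
# "IndexError: too many indices for tensor of dimension 2".
padding_mask_2d = torch.ones(batch, kv_len, dtype=torch.long)
print(padding_mask_2d[:, :key_len].shape)  # torch.Size([2, 5])
```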

src/transformers/models/whisper/generation_whisper.py

Lines changed: 20 additions & 5 deletions
```diff
@@ -173,7 +173,9 @@ def _pad_to_max_length(
 
 
 class WhisperGenerationMixin(GenerationMixin):
-    def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None):
+    def _extract_token_timestamps(
+        self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None, num_input_ids=None
+    ):
         """
         Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
         map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
```
```diff
@@ -200,11 +202,18 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec
         # since the beam search strategy chooses the most probable sequences at the end of the search.
         # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
         weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
+        weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids
+
+        # beam search takes `decoder_input_ids` into account in the `beam_indices` length
+        # but forgot to shift the beam_indices by the number of `decoder_input_ids`
+        beam_indices = torch.zeros_like(generate_outputs.beam_indices[:, :weight_length])
+        # we actually shif the beam indices here
+        beam_indices[:, num_input_ids:] = generate_outputs.beam_indices[:, : weight_length - num_input_ids]
+
         weights = weights[:, :, :weight_length]
 
         # If beam index is still -1, it means that the associated token id is EOS
         # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
-        beam_indices = generate_outputs.beam_indices[:, :weight_length]
         beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
 
         # Select the cross attention from the right beam for each output sequences
```
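To make the shift concrete, here is a toy illustration with hypothetical values (not taken from a real model run): beam search reports indices whose length already includes the `decoder_input_ids` prompt, so the values are padded on the left with beam 0 before they are used to index the cross-attention weights.

```python
import torch

# Hypothetical beam indices for 2 sequences; -1 marks padding after EOS.
beam_indices_raw = torch.tensor([[2, 0, 1, -1, -1, -1],
                                 [1, 1, 1, 0, -1, -1]])
num_input_ids = 2  # e.g. task + language prompt tokens

# Longest non-padded run, extended by the prompt length: 4 + 2 = 6
weight_length = (beam_indices_raw != -1).sum(-1).max() + num_input_ids

# Shift right by num_input_ids so each index lines up with the cross-attention
# column of the token it selected; prompt positions default to beam 0.
beam_indices = torch.zeros_like(beam_indices_raw[:, :weight_length])
beam_indices[:, num_input_ids:] = beam_indices_raw[:, : weight_length - num_input_ids]
print(beam_indices)
# tensor([[ 0,  0,  2,  0,  1, -1],
#         [ 0,  0,  1,  1,  1,  0]])
```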
```diff
@@ -218,8 +227,10 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec
 
         # make sure timestamps are as long as weights
         input_length = weight_length or cross_attentions[0].shape[2]
-        timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)[:, : input_length + 1]
-        batch_size = timestamps.shape[0]
+        batch_size = generate_outputs.sequences.shape[0]
+        timestamps = torch.zeros(
+            (batch_size, input_length + 1), dtype=torch.float32, device=generate_outputs.sequences.device
+        )
 
         if num_frames is not None:
             # two cases:
```
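A toy sketch of what the new allocation guarantees compared to the old one (illustrative shapes; the motivation stated here is inferred from the diff, not from the PR discussion): `zeros_like` on the generated sequences caps the buffer at the sequence length, whereas the explicit `torch.zeros` call always yields `input_length + 1` columns on the sequences' device.

```python
import torch

sequences = torch.ones((2, 5), dtype=torch.long)  # toy generated ids: batch=2, length=5
input_length = 8                                  # toy cross-attention length

# Old approach: the slice can never exceed the sequences' own length (5, not 9).
old = torch.zeros_like(sequences, dtype=torch.float32)[:, : input_length + 1]
print(old.shape)  # torch.Size([2, 5])

# New approach: always (batch_size, input_length + 1), allocated on sequences.device.
batch_size = sequences.shape[0]
new = torch.zeros((batch_size, input_length + 1), dtype=torch.float32, device=sequences.device)
print(new.shape)  # torch.Size([2, 9])
```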
```diff
@@ -239,6 +250,7 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec
             else:
                 # num_frames is of shape (batch_size,) whereas batch_size is truely batch_size*num_return_sequences
                 repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
+                num_frames = num_frames.cpu() if isinstance(num_frames, (torch.Tensor)) else num_frames
                 num_frames = np.repeat(num_frames, repeat_time)
 
         if num_frames is None or isinstance(num_frames, int):
```
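A small sketch of why `.cpu()` is needed before `np.repeat` when `num_frames` arrives as a tensor (the CUDA branch below only triggers when a GPU is available): NumPy cannot read GPU memory, while CPU tensors convert to NumPy transparently.

```python
import numpy as np
import torch

num_frames = torch.tensor([1500, 3000])
if torch.cuda.is_available():
    num_frames = num_frames.cuda()

# np.repeat on a CUDA tensor raises
# "TypeError: can't convert cuda:0 device type tensor to numpy", so move it first.
num_frames = num_frames.cpu() if isinstance(num_frames, torch.Tensor) else num_frames
print(np.repeat(num_frames, 2))  # [1500 1500 3000 3000]
```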
```diff
@@ -948,7 +960,10 @@ def _postprocess_outputs(
         if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
             num_frames = getattr(generation_config, "num_frames", None)
             seek_outputs["token_timestamps"] = self._extract_token_timestamps(
-                seek_outputs, generation_config.alignment_heads, num_frames=num_frames
+                seek_outputs,
+                generation_config.alignment_heads,
+                num_frames=num_frames,
+                num_input_ids=decoder_input_ids.shape[-1],
             )
             seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
 
```
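For orientation, a hedged end-to-end sketch of the code path these hunks sit on, using only the public `generate` API (checkpoint, audio, and generation settings are illustrative; token-level timestamps require a checkpoint whose generation config defines `alignment_heads`, as the official Whisper checkpoints do).

```python
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# 30 s of silence stands in for real 16 kHz audio.
audio = torch.zeros(16000 * 30).numpy()
input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features

# With beam search, _extract_token_timestamps runs under the hood and relies on the
# shifted beam indices above to pick the right cross-attention weights.
out = model.generate(
    input_features,
    num_beams=2,
    return_timestamps=True,
    return_token_timestamps=True,
)
print(out["sequences"].shape, out["token_timestamps"].shape)
```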

src/transformers/models/whisper/modeling_whisper.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -422,7 +422,7 @@ def forward(
 
         causal_mask = attention_mask
         if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            causal_mask = attention_mask[:, : key_states.shape[-2]]
 
         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
```

tests/models/whisper/test_modeling_whisper.py

Lines changed: 11 additions & 11 deletions
```diff
@@ -1916,14 +1916,14 @@ def test_large_generation_multilingual(self):
             input_features, do_sample=False, max_length=20, language="<|de|>", task="transcribe"
         )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-        EXPECTED_TRANSCRIPT = " Mein sechster Sohn scheint, wenigstens auf den ersten Blick,"
+        EXPECTED_TRANSCRIPT = " Denken Sie, soeben walten meine Gedanken bei Ihnen in Adela"
         self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
 
         generated_ids = model.generate(
             input_features, do_sample=False, max_length=20, language="<|de|>", task="translate"
         )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-        EXPECTED_TRANSCRIPT = " My sixth son seems, at least at first glance, the most deeply-minded"
+        EXPECTED_TRANSCRIPT = " Think, my thoughts were just rolling with you in Adelaide, and I"
         self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
 
     @slow
```
```diff
@@ -2238,7 +2238,7 @@ def test_tiny_token_timestamp_generation(self):
             input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
         )
 
-        self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
+        self.assertEqual(generate_outputs["sequences"].shape, generate_outputs["token_timestamps"].shape)
 
         # fmt: off
         EXPECTED_OUTPUT = torch.tensor([
```
```diff
@@ -2249,7 +2249,7 @@ def test_tiny_token_timestamp_generation(self):
         ])
         # fmt: on
 
-        self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
+        self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))
 
     @slow
     def test_large_token_timestamp_generation(self):
```
```diff
@@ -2268,7 +2268,7 @@ def test_large_token_timestamp_generation(self):
             **input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
         )
 
-        self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
+        self.assertEqual(generate_outputs["sequences"].shape, generate_outputs["token_timestamps"].shape)
 
         # fmt: off
         EXPECTED_OUTPUT = torch.tensor([
```
```diff
@@ -2279,7 +2279,7 @@ def test_large_token_timestamp_generation(self):
         ])
         # fmt: on
 
-        self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
+        self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))
 
     @slow
     def test_tiny_token_timestamp_batch_generation(self):
```
```diff
@@ -2306,9 +2306,9 @@ def test_tiny_token_timestamp_batch_generation(self):
         )
 
         # task id and lang id prompts should not have timestamp tokens
-        self.assertEqual(generate_outputs.sequences.shape[-1] - 2, generate_outputs.token_timestamps.shape[-1])
+        self.assertEqual(generate_outputs["sequences"].shape[-1] - 2, generate_outputs["token_timestamps"].shape[-1])
 
-        self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)
+        self.assertEqual(len(generate_outputs["sequences"]), num_return_sequences * num_samples)
 
     @slow
     def test_tiny_token_timestamp_generation_longform(self):
```
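A brief note on the bracketed indexing adopted in these assertions: the object returned by `generate` here supports mapping-style access, and bracketed access also works on a plain dict, so it is the more permissive spelling. A minimal stand-in (not a real generate output):

```python
import torch

# Stand-in for the generate output with return_token_timestamps=True; a ModelOutput
# would additionally allow attribute access, a plain dict would not.
generate_outputs = {
    "sequences": torch.ones(2, 6, dtype=torch.long),
    "token_timestamps": torch.zeros(2, 6),
}
assert generate_outputs["sequences"].shape == generate_outputs["token_timestamps"].shape
```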
```diff
@@ -2799,7 +2799,7 @@ def test_whisper_shortform_single_batch_prev_cond(self):
 
         torch.manual_seed(0)
         result = model.generate(input_features, **gen_kwargs)
-        decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
+        decoded = processor.batch_decode(result, skip_special_tokens=True)
 
         assert decoded == EXPECTED_TEXT
 
```
```diff
@@ -2814,7 +2814,7 @@ def test_whisper_shortform_single_batch_prev_cond(self):
 
         torch.manual_seed(0)
         result = model.generate(input_features, **gen_kwargs)
-        decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
+        decoded = processor.batch_decode(result, skip_special_tokens=True)
 
         assert decoded == EXPECTED_TEXT1
 
```
```diff
@@ -3114,7 +3114,7 @@ def test_whisper_shortform_multi_batch_hard_prev_cond(self):
         }
 
         result = model.generate(**inputs, **gen_kwargs)
-        decoded_all = processor.batch_decode(result.sequences, skip_special_tokens=True)
+        decoded_all = processor.batch_decode(result, skip_special_tokens=True)
 
         for i in range(num_samples):
             if isinstance(EXPECTED_TEXT[i], str):
```
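The shortform tests stop unpacking `.sequences` before decoding; presumably `generate` returns a plain tensor of token ids on this path (no `return_dict_in_generate` in `gen_kwargs`), which `batch_decode` consumes directly. A hedged sketch (checkpoint and input are illustrative):

```python
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

input_features = processor(
    torch.zeros(16000 * 30).numpy(), sampling_rate=16000, return_tensors="pt"
).input_features

# Without return_dict_in_generate, the result is a LongTensor of token ids,
# so there is no .sequences attribute to unpack before decoding.
result = model.generate(input_features, max_new_tokens=20)
decoded = processor.batch_decode(result, skip_special_tokens=True)
print(decoded)
```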
