Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/transformers/models/whisper/generation_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def generate(

# 3. Make sure generation config is correctly set
# Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
self._set_return_outputs(
return_dict_in_generate = self._set_return_outputs(
return_dict_in_generate=return_dict_in_generate,
return_token_timestamps=return_token_timestamps,
logprob_threshold=logprob_threshold,
Expand Down Expand Up @@ -732,7 +732,7 @@ def generate(
else:
outputs = sequences

if generation_config.return_dict_in_generate:
if return_dict_in_generate and generation_config.return_dict_in_generate:
dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs)

if num_return_sequences > 1:
Expand Down Expand Up @@ -1109,18 +1109,20 @@ def _maybe_warn_unused_inputs(
def _set_return_outputs(return_dict_in_generate, return_token_timestamps, logprob_threshold, generation_config):
if return_dict_in_generate is None:
return_dict_in_generate = generation_config.return_dict_in_generate
else:
generation_config.return_dict_in_generate = return_dict_in_generate

generation_config.return_token_timestamps = return_token_timestamps
if return_token_timestamps:
return_dict_in_generate = True
generation_config.return_dict_in_generate = True
generation_config.output_attentions = True
generation_config.output_scores = True

if logprob_threshold is not None:
return_dict_in_generate = True
generation_config.return_dict_in_generate = True
generation_config.output_scores = True

generation_config.return_dict_in_generate = return_dict_in_generate
return return_dict_in_generate

def _set_return_timestamps(self, return_timestamps, is_shortform, generation_config):
if not is_shortform:
Expand Down
22 changes: 22 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import numpy as np
import pytest
from huggingface_hub import hf_hub_download
from parameterized import parameterized

import transformers
from transformers import WhisperConfig
Expand Down Expand Up @@ -72,6 +73,7 @@
BeamSearchEncoderDecoderOutput,
GenerateBeamDecoderOnlyOutput,
GenerateBeamEncoderDecoderOutput,
GenerateEncoderDecoderOutput,
PhrasalConstraint,
)
from transformers.generation.logits_process import LogitsProcessor
Expand Down Expand Up @@ -1820,6 +1822,26 @@ def test_custom_4d_attention_mask(self):
normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)

@parameterized.expand([(True,), (False,)])
def test_generate_output_type(self, return_dict_in_generate):
    """`generate` must return a `GenerateEncoderDecoderOutput` when
    `return_dict_in_generate=True` and a plain tensor otherwise — on both the
    plain short-form path and the temperature-fallback path."""
    if return_dict_in_generate:
        expected_output_type = GenerateEncoderDecoderOutput
    else:
        expected_output_type = torch.Tensor

    for model_class in self.all_generative_model_classes:
        config, inputs = self.model_tester.prepare_config_and_inputs()
        model = model_class(config).to(torch_device).eval()

        # Short-form generation, no fallback triggered.
        short_form_out = model.generate(**inputs, return_dict_in_generate=return_dict_in_generate)
        assert isinstance(short_form_out, expected_output_type)

        # Short-form generation where the logprob threshold plus a temperature
        # schedule forces the fallback path.
        fallback_out = model.generate(
            **inputs,
            logprob_threshold=-1.0,
            temperature=[0.0, 0.1],
            return_dict_in_generate=return_dict_in_generate,
        )
        assert isinstance(fallback_out, expected_output_type)


@require_torch
@require_torchaudio
Expand Down