diff --git a/whisper/decoding.py b/whisper/decoding.py
index eaedf70d8..c604631c3 100644
--- a/whisper/decoding.py
+++ b/whisper/decoding.py
@@ -94,7 +94,7 @@ class DecodingOptions:
 
     # timestamp sampling options
     without_timestamps: bool = False              # use <|notimestamps|> to sample text tokens only
-    max_initial_timestamp: Optional[float] = 0.0  # the initial timestamp cannot be later than this
+    max_initial_timestamp: Optional[float] = 1.0  # the initial timestamp cannot be later than this
 
     # implementation details
     fp16: bool = True  # use fp16 for most of the calculation
diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 262361391..f97029989 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -92,41 +92,37 @@ def transcribe(
         if verbose is not None:
             print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
 
-    mel = mel.unsqueeze(0)
     language = decode_options["language"]
     task = decode_options.get("task", "transcribe")
     tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
 
-    def decode_with_fallback(segment: torch.Tensor) -> List[DecodingResult]:
+    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
         temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
-        kwargs = {**decode_options}
-        t = temperatures[0]
-        if t == 0:
-            best_of = kwargs.pop("best_of", None)
-        else:
-            best_of = kwargs.get("best_of", None)
-
-        options = DecodingOptions(**kwargs, temperature=t)
-        results = model.decode(segment, options)
-
-        kwargs.pop("beam_size", None)  # no beam search for t > 0
-        kwargs.pop("patience", None)  # no patience for t > 0
-        kwargs["best_of"] = best_of  # enable best_of for t > 0
-        for t in temperatures[1:]:
-            needs_fallback = [
-                compression_ratio_threshold is not None
-                and result.compression_ratio > compression_ratio_threshold
-                or logprob_threshold is not None
-                and result.avg_logprob < logprob_threshold
-                for result in results
-            ]
-            if any(needs_fallback):
-                options = DecodingOptions(**kwargs, temperature=t)
-                retries = model.decode(segment[needs_fallback], options)
-                for retry_index, original_index in enumerate(np.nonzero(needs_fallback)[0]):
-                    results[original_index] = retries[retry_index]
-
-        return results
+        decode_result = None
+
+        for t in temperatures:
+            kwargs = {**decode_options}
+            if t > 0:
+                # disable beam_size and patience when t > 0
+                kwargs.pop("beam_size", None)
+                kwargs.pop("patience", None)
+            else:
+                # disable best_of when t == 0
+                kwargs.pop("best_of", None)
+
+            options = DecodingOptions(**kwargs, temperature=t)
+            decode_result = model.decode(segment, options)
+
+            needs_fallback = False
+            if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
+                needs_fallback = True  # too repetitive
+            if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
+                needs_fallback = True  # average log probability is too low
+
+            if not needs_fallback:
+                break
+
+        return decode_result
 
     seek = 0
     input_stride = exact_div(
@@ -175,11 +171,11 @@ def add_segment(
     with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
         while seek < num_frames:
             timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
+            segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
             segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
 
             decode_options["prompt"] = all_tokens[prompt_reset_since:]
-            result = decode_with_fallback(segment)[0]
+            result: DecodingResult = decode_with_fallback(segment)
             tokens = torch.tensor(result.tokens)
 
             if no_speech_threshold is not None:
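
For context, a minimal, self-contained sketch of the temperature-fallback pattern this diff introduces: decode once at the lowest temperature, and only retry at progressively higher temperatures while the result looks degenerate (too repetitive, or too low an average log probability). `StubResult` and `stub_decode` are hypothetical stand-ins for Whisper's `DecodingResult` and `model.decode`, used purely for illustration; the threshold defaults mirror those of `transcribe()`.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class StubResult:
    text: str
    avg_logprob: float
    compression_ratio: float


def stub_decode(temperature: float) -> StubResult:
    # Pretend greedy decoding (t == 0) produces a repetitive transcript
    # and that sampling at a higher temperature recovers.
    if temperature == 0.0:
        return StubResult("la la la la la", avg_logprob=-0.4, compression_ratio=3.2)
    return StubResult("hello world", avg_logprob=-0.3, compression_ratio=1.4)


def decode_with_fallback(
    temperatures: List[float],
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
) -> StubResult:
    decode_result = None
    for t in temperatures:
        decode_result = stub_decode(t)

        needs_fallback = False
        if (
            compression_ratio_threshold is not None
            and decode_result.compression_ratio > compression_ratio_threshold
        ):
            needs_fallback = True  # too repetitive
        if (
            logprob_threshold is not None
            and decode_result.avg_logprob < logprob_threshold
        ):
            needs_fallback = True  # average log probability is too low

        if not needs_fallback:
            break  # keep the first acceptable result
    return decode_result


print(decode_with_fallback([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]).text)  # -> "hello world"
```

Compared with the removed batched version, this loop decodes one segment at a time and stops at the first temperature whose output passes both quality checks, which is why the function now returns a single `DecodingResult` rather than a list.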