allowing nonzero initial temperature
jongwook committed Sep 30, 2022
1 parent 30dc5c5 commit 7cb4cc2
Showing 2 changed files with 29 additions and 33 deletions.
2 changes: 1 addition & 1 deletion whisper/decoding.py
@@ -94,7 +94,7 @@ class DecodingOptions:

     # timestamp sampling options
     without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
-    max_initial_timestamp: Optional[float] = 0.0  # the initial timestamp cannot be later than this
+    max_initial_timestamp: Optional[float] = 1.0  # the initial timestamp cannot be later than this

     # implementation details
     fp16: bool = True  # use fp16 for most of the calculation
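The new default lets a segment's first timestamp fall as late as 1.0 s into the 30-second window instead of pinning it to 0.00. As a rough illustration of how such a cap becomes a logit mask, here is a minimal sketch assuming Whisper's usual 0.02 s timestamp granularity (30 s over 1500 audio positions); mask_initial_timestamps is a hypothetical helper, not the library's API, though the masking mirrors the ApplyTimestampRules filter in whisper/decoding.py:

import numpy as np

# Assumed constants: 30 s windows over 1500 audio positions give
# one timestamp token per 0.02 s of audio.
CHUNK_LENGTH = 30
N_AUDIO_CTX = 1500
PRECISION = CHUNK_LENGTH / N_AUDIO_CTX  # 0.02 seconds per timestamp token

def mask_initial_timestamps(logits: np.ndarray, timestamp_begin: int,
                            max_initial_timestamp: float = 1.0) -> np.ndarray:
    # At the first decoding step, forbid timestamp tokens that would start
    # the segment later than max_initial_timestamp.
    max_index = round(max_initial_timestamp / PRECISION)
    last_allowed = timestamp_begin + max_index
    logits[..., last_allowed + 1:] = -np.inf
    return logits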
60 changes: 28 additions & 32 deletions whisper/transcribe.py
@@ -92,41 +92,37 @@ def transcribe(
     if verbose is not None:
         print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")

-    mel = mel.unsqueeze(0)
     language = decode_options["language"]
     task = decode_options.get("task", "transcribe")
     tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)

-    def decode_with_fallback(segment: torch.Tensor) -> List[DecodingResult]:
+    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
         temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
-        kwargs = {**decode_options}
-        t = temperatures[0]
-        if t == 0:
-            best_of = kwargs.pop("best_of", None)
-        else:
-            best_of = kwargs.get("best_of", None)
-
-        options = DecodingOptions(**kwargs, temperature=t)
-        results = model.decode(segment, options)
-
-        kwargs.pop("beam_size", None)  # no beam search for t > 0
-        kwargs.pop("patience", None)  # no patience for t > 0
-        kwargs["best_of"] = best_of  # enable best_of for t > 0
-        for t in temperatures[1:]:
-            needs_fallback = [
-                compression_ratio_threshold is not None
-                and result.compression_ratio > compression_ratio_threshold
-                or logprob_threshold is not None
-                and result.avg_logprob < logprob_threshold
-                for result in results
-            ]
-            if any(needs_fallback):
-                options = DecodingOptions(**kwargs, temperature=t)
-                retries = model.decode(segment[needs_fallback], options)
-                for retry_index, original_index in enumerate(np.nonzero(needs_fallback)[0]):
-                    results[original_index] = retries[retry_index]
-
-        return results
+        decode_result = None
+
+        for t in temperatures:
+            kwargs = {**decode_options}
+            if t > 0:
+                # disable beam_size and patience when t > 0
+                kwargs.pop("beam_size", None)
+                kwargs.pop("patience", None)
+            else:
+                # disable best_of when t == 0
+                kwargs.pop("best_of", None)
+
+            options = DecodingOptions(**kwargs, temperature=t)
+            decode_result = model.decode(segment, options)
+
+            needs_fallback = False
+            if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
+                needs_fallback = True  # too repetitive
+            if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
+                needs_fallback = True  # average log probability is too low
+
+            if not needs_fallback:
+                break
+
+        return decode_result

     seek = 0
     input_stride = exact_div(
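The rewritten decode_with_fallback decodes the segment once per temperature, walking the schedule left to right and stopping as soon as neither quality check trips: greedy decoding (and beam search, if configured) runs at t == 0, and sampling with best_of only kicks in when a decode looks too repetitive or too low-confidence. A sketch of how this surfaces through the public API, using the defaults documented in transcribe() (the audio path is a placeholder):

import whisper

model = whisper.load_model("base")

# temperature may be a single float or a sequence; with a sequence, each
# segment retries at the next value until both thresholds pass.
result = model.transcribe(
    "audio.mp3",
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold=2.4,  # above this: too repetitive, retry hotter
    logprob_threshold=-1.0,           # below this: too uncertain, retry hotter
)
print(result["text"])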
@@ -175,11 +171,11 @@ def add_segment(
     with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
         while seek < num_frames:
             timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
+            segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
             segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE

             decode_options["prompt"] = all_tokens[prompt_reset_since:]
-            result = decode_with_fallback(segment)[0]
+            result: DecodingResult = decode_with_fallback(segment)
             tokens = torch.tensor(result.tokens)

             if no_speech_threshold is not None:
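With mel.unsqueeze(0) removed (first hunk above), the spectrogram stays 2-D, so the seek loop now slices (n_mels, n_frames) directly; model.decode accepts the unbatched segment and returns a single DecodingResult, which is why the trailing [0] is gone. A minimal sketch of the windowing arithmetic, assuming the helpers in whisper.audio and a placeholder audio path:

from whisper.audio import (HOP_LENGTH, N_FRAMES, SAMPLE_RATE,
                           log_mel_spectrogram, pad_or_trim)

mel = log_mel_spectrogram("audio.mp3")  # shape (n_mels, n_frames), no batch axis
seek = 0
# one 30-second window, padded/trimmed along the frame axis to 3000 frames
segment = pad_or_trim(mel[:, seek:], N_FRAMES)
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)  # window start, seconds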