Merge pull request #199 from shahrukhx01/master
fix typo in transcribe.py
Jeronymous authored Jul 22, 2024
2 parents a82e4d8 + 025c510 commit 932f8f6
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions whisper_timestamped/transcribe.py
@@ -76,7 +76,7 @@ def transcribe_timestamped(
refine_whisper_precision=0.5,
min_word_duration=0.02, # Was 0.04 before 1.11
plot_word_alignment=False,
- word_alignement_most_top_layers=None, # Was 6 before 1.9
+ word_alignment_most_top_layers=None, # Was 6 before 1.9
remove_empty_words=False,
use_backend_timestamps=False,

@@ -214,7 +214,7 @@ def transcribe_timestamped(
assert refine_whisper_precision >= 0 and refine_whisper_precision / AUDIO_TIME_PER_TOKEN == round(refine_whisper_precision / AUDIO_TIME_PER_TOKEN), f"refine_whisper_precision must be a positive multiple of {AUDIO_TIME_PER_TOKEN}"
refine_whisper_precision_nframes = round(refine_whisper_precision / AUDIO_TIME_PER_TOKEN)
assert min_word_duration >= 0, f"min_word_duration must be a positive number"
- assert word_alignement_most_top_layers is None or word_alignement_most_top_layers > 0, f"word_alignement_most_top_layers must be a strictly positive number"
+ assert word_alignment_most_top_layers is None or word_alignment_most_top_layers > 0, f"word_alignment_most_top_layers must be a strictly positive number"

if isinstance(temperature, (list, tuple)) and len(temperature) == 1:
temperature = temperature[0]
@@ -242,9 +242,9 @@ def transcribe_timestamped(
time_precision = input_stride * HOP_LENGTH / SAMPLE_RATE
assert time_precision == AUDIO_TIME_PER_TOKEN

- alignment_heads = get_alignment_heads(model) if word_alignement_most_top_layers is None else None
- if alignment_heads is None and word_alignement_most_top_layers is None:
- word_alignement_most_top_layers = 6
+ alignment_heads = get_alignment_heads(model) if word_alignment_most_top_layers is None else None
+ if alignment_heads is None and word_alignment_most_top_layers is None:
+ word_alignment_most_top_layers = 6

alignment_options = dict(
remove_punctuation_from_words=remove_punctuation_from_words,
@@ -253,7 +253,7 @@
detect_disfluencies=detect_disfluencies,
refine_whisper_precision_nframes=refine_whisper_precision_nframes,
plot_word_alignment=plot_word_alignment,
- word_alignement_most_top_layers=word_alignement_most_top_layers,
+ word_alignment_most_top_layers=word_alignment_most_top_layers,
alignment_heads=alignment_heads,
)
whisper_options = dict(
@@ -351,7 +351,7 @@ def _transcribe_timestamped_efficient(
refine_whisper_precision_nframes,
alignment_heads,
plot_word_alignment,
- word_alignement_most_top_layers,
+ word_alignment_most_top_layers,
detect_disfluencies,
trust_whisper_timestamps,
use_timestamps_for_alignment = True,
@@ -378,13 +378,13 @@ def _transcribe_timestamped_efficient(

debug = logger.getEffectiveLevel() >= logging.DEBUG

- word_alignement_most_top_layers = float("inf") if word_alignement_most_top_layers is None else word_alignement_most_top_layers
+ word_alignment_most_top_layers = float("inf") if word_alignment_most_top_layers is None else word_alignment_most_top_layers

# The main outcome
timestamped_word_segments = [] # list of timestamped word segments that have been collected so far
# Main variables to be accumulated
segment_tokens = [[]] # list of lists of token indices that have been collected so far (one list per segment)
- segment_attweights = [[] for _ in range(min(word_alignement_most_top_layers, len(model.decoder.blocks)))]
+ segment_attweights = [[] for _ in range(min(word_alignment_most_top_layers, len(model.decoder.blocks)))]
# attention weights on the last segments
segment_avglogprobs = [] # average log probability for each segment (actually of the corresponding chunk, as computed by whisper)
segment_logprobs = [] # token log probabilities for each segment
@@ -875,7 +875,7 @@ def hook_output_logits(layer, ins, outs):
nblocks = len(model.decoder.blocks)
j = 0
for i, block in enumerate(model.decoder.blocks):
- if i < nblocks - word_alignement_most_top_layers:
+ if i < nblocks - word_alignment_most_top_layers:
continue
all_hooks.append(
block.cross_attn.register_forward_hook(
@@ -995,7 +995,7 @@ def _transcribe_timestamped_naive(
use_backend_timestamps,
alignment_heads,
plot_word_alignment,
- word_alignement_most_top_layers,
+ word_alignment_most_top_layers,
detect_disfluencies,
trust_whisper_timestamps,
min_word_duration,
@@ -1006,7 +1006,7 @@
language = whisper_options["language"]
refine_whisper_precision_sec = refine_whisper_precision_nframes * AUDIO_TIME_PER_TOKEN

- word_alignement_most_top_layers = float("inf") if word_alignement_most_top_layers is None else word_alignement_most_top_layers
+ word_alignment_most_top_layers = float("inf") if word_alignment_most_top_layers is None else word_alignment_most_top_layers

audio = get_audio_tensor(audio)
audio_duration = audio.shape[-1] / SAMPLE_RATE
@@ -1077,7 +1077,7 @@ def hook_output_logits(layer, ins, outs):

n_mels = model.dims.n_mels if hasattr(model.dims, "n_mels") else 80

- attention_weights = [[] for _ in range(min(word_alignement_most_top_layers, len(model.decoder.blocks)))]
+ attention_weights = [[] for _ in range(min(word_alignment_most_top_layers, len(model.decoder.blocks)))]

try:

@@ -1087,7 +1087,7 @@ def hook_output_logits(layer, ins, outs):
nblocks = len(model.decoder.blocks)
j = 0
for i, block in enumerate(model.decoder.blocks):
- if i < nblocks - word_alignement_most_top_layers:
+ if i < nblocks - word_alignment_most_top_layers:
continue
def hook(layer, ins, outs, index=j):
if is_transformer_model(model):

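For reference, a minimal usage sketch with the corrected keyword spelling. The model name, device, and audio path below are illustrative placeholders and are not part of this diff; only the transcribe_timestamped call and its keyword defaults come from the changed file.

    import whisper_timestamped as whisper_ts

    # Placeholder model and audio file, assumed for illustration only.
    model = whisper_ts.load_model("tiny", device="cpu")

    # After this fix the keyword is spelled word_alignment_most_top_layers.
    # Leaving it at None lets the library use the model's alignment heads,
    # falling back to the top 6 decoder layers when none are available.
    result = whisper_ts.transcribe_timestamped(
        model,
        "audio.wav",
        word_alignment_most_top_layers=None,
        plot_word_alignment=False,
    )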