Skip to content

Commit

Permalink
nocaptions -> nospeech to match the paper figure
Browse files Browse the repository at this point in the history
  • Loading branch information
jongwook committed Sep 23, 2022
1 parent 6198952 commit 15ab548
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 39 deletions.
28 changes: 14 additions & 14 deletions whisper/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class DecodingResult:
tokens: List[int] = field(default_factory=list)
text: str = ""
avg_logprob: float = np.nan
no_caption_prob: float = np.nan
no_speech_prob: float = np.nan
temperature: float = np.nan
compression_ratio: float = np.nan

Expand Down Expand Up @@ -543,9 +543,9 @@ def _get_suppress_tokens(self) -> Tuple[int]:
suppress_tokens.extend(
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
)
if self.tokenizer.no_captions is not None:
# no-captions probability is collected separately
suppress_tokens.append(self.tokenizer.no_captions)
if self.tokenizer.no_speech is not None:
# no-speech probability is collected separately
suppress_tokens.append(self.tokenizer.no_speech)

return tuple(sorted(set(suppress_tokens)))

Expand Down Expand Up @@ -580,15 +580,15 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
assert audio_features.shape[0] == tokens.shape[0]
n_batch = tokens.shape[0]
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
no_caption_probs = [np.nan] * n_batch
no_speech_probs = [np.nan] * n_batch

try:
for i in range(self.sample_len):
logits = self.inference.logits(tokens, audio_features)

if i == 0 and self.tokenizer.no_captions is not None: # save no_caption_probs
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
no_caption_probs = probs_at_sot[:, self.tokenizer.no_captions].tolist()
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

# now we need to consider the logits at the last token only
logits = logits[:, -1]
Expand All @@ -605,7 +605,7 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
finally:
self.inference.cleanup_caching()

return tokens, sum_logprobs, no_caption_probs
return tokens, sum_logprobs, no_speech_probs

@torch.no_grad()
def run(self, mel: Tensor) -> List[DecodingResult]:
Expand All @@ -629,12 +629,12 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

# call the main sampling loop
tokens, sum_logprobs, no_caption_probs = self._main_loop(audio_features, tokens)
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

# reshape the tensors to have (n_audio, n_group) as the first two dimensions
audio_features = audio_features[:: self.n_group]
no_caption_probs = no_caption_probs[:: self.n_group]
assert audio_features.shape[0] == len(no_caption_probs) == n_audio
no_speech_probs = no_speech_probs[:: self.n_group]
assert audio_features.shape[0] == len(no_speech_probs) == n_audio

tokens = tokens.reshape(n_audio, self.n_group, -1)
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
Expand All @@ -653,7 +653,7 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]

fields = (texts, languages, tokens, audio_features, avg_logprobs, no_caption_probs)
fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
if len(set(map(len, fields))) != 1:
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

Expand All @@ -664,11 +664,11 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
tokens=tokens,
text=text,
avg_logprob=avg_logprob,
no_caption_prob=no_caption_prob,
no_speech_prob=no_speech_prob,
temperature=self.options.temperature,
compression_ratio=compression_ratio(text),
)
for text, language, tokens, features, avg_logprob, no_caption_prob in zip(*fields)
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
]


Expand Down
6 changes: 3 additions & 3 deletions whisper/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ def sot_prev(self) -> int:

@property
@lru_cache()
def no_captions(self) -> int:
return self._get_single_token_id("<|nocaptions|>")
def no_speech(self) -> int:
return self._get_single_token_id("<|nospeech|>")

@property
@lru_cache()
Expand Down Expand Up @@ -283,7 +283,7 @@ def build_tokenizer(name: str = "gpt2"):
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nocaptions|>",
"<|nospeech|>",
"<|notimestamps|>",
]

Expand Down
32 changes: 10 additions & 22 deletions whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def transcribe(
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
no_captions_threshold: Optional[float] = 0.6,
no_speech_threshold: Optional[float] = 0.6,
**decode_options,
):
"""
Expand All @@ -50,8 +50,8 @@ def transcribe(
logprob_threshold: float
If the average log probability over sampled tokens is below this value, treat as failed
no_captions_threshold: float
If the no_captions probability is higher than this value AND the average log probability
no_speech_threshold: float
If the no_speech probability is higher than this value AND the average log probability
over sampled tokens is below `logprob_threshold`, consider the segment as silent
decode_options: dict
Expand Down Expand Up @@ -148,7 +148,7 @@ def add_segment(
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_caption_prob": result.no_caption_prob,
"no_speech_prob": result.no_speech_prob,
}
)
if verbose:
Expand All @@ -163,11 +163,11 @@ def add_segment(
result = decode_with_fallback(segment)[0]
tokens = torch.tensor(result.tokens)

if no_captions_threshold is not None:
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_caption_prob > no_captions_threshold
should_skip = result.no_speech_prob > no_speech_threshold
if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
# don't skip if the logprob is high enough, despite the no_captions_prob
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False

if should_skip:
Expand Down Expand Up @@ -249,7 +249,7 @@ def cli():
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_caption_threshold", type=optional_float, default=0.6, help="if the probability of the <|nocaptions|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")

args = parser.parse_args().__dict__
model_name: str = args.pop("model")
Expand All @@ -261,12 +261,8 @@ def cli():
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
args["language"] = "en"

temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
compression_ratio_threshold = args.pop("compression_ratio_threshold")
logprob_threshold = args.pop("logprob_threshold")
no_caption_threshold = args.pop("no_caption_threshold")

temperature = args.pop("temperature")
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
if temperature_increment_on_fallback is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
else:
Expand All @@ -276,15 +272,7 @@ def cli():
model = load_model(model_name, device=device)

for audio_path in args.pop("audio"):
result = transcribe(
model,
audio_path,
temperature=temperature,
compression_ratio_threshold=compression_ratio_threshold,
logprob_threshold=logprob_threshold,
no_captions_threshold=no_caption_threshold,
**args,
)
result = transcribe(model, audio_path, temperature=temperature, **args)

audio_basename = os.path.basename(audio_path)

Expand Down

0 comments on commit 15ab548

Please sign in to comment.