diff --git a/whisper/decoding.py b/whisper/decoding.py index fa8cf2dcb..bd6322848 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -108,7 +108,7 @@ class DecodingResult: tokens: List[int] = field(default_factory=list) text: str = "" avg_logprob: float = np.nan - no_caption_prob: float = np.nan + no_speech_prob: float = np.nan temperature: float = np.nan compression_ratio: float = np.nan @@ -543,9 +543,9 @@ def _get_suppress_tokens(self) -> Tuple[int]: suppress_tokens.extend( [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm] ) - if self.tokenizer.no_captions is not None: - # no-captions probability is collected separately - suppress_tokens.append(self.tokenizer.no_captions) + if self.tokenizer.no_speech is not None: + # no-speech probability is collected separately + suppress_tokens.append(self.tokenizer.no_speech) return tuple(sorted(set(suppress_tokens))) @@ -580,15 +580,15 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor): assert audio_features.shape[0] == tokens.shape[0] n_batch = tokens.shape[0] sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device) - no_caption_probs = [np.nan] * n_batch + no_speech_probs = [np.nan] * n_batch try: for i in range(self.sample_len): logits = self.inference.logits(tokens, audio_features) - if i == 0 and self.tokenizer.no_captions is not None: # save no_caption_probs + if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1) - no_caption_probs = probs_at_sot[:, self.tokenizer.no_captions].tolist() + no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() # now we need to consider the logits at the last token only logits = logits[:, -1] @@ -605,7 +605,7 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor): finally: self.inference.cleanup_caching() - return tokens, sum_logprobs, no_caption_probs + return tokens, sum_logprobs, no_speech_probs @torch.no_grad() def run(self, mel: Tensor) -> List[DecodingResult]: @@ -629,12 +629,12 @@ def run(self, mel: Tensor) -> List[DecodingResult]: tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device) # call the main sampling loop - tokens, sum_logprobs, no_caption_probs = self._main_loop(audio_features, tokens) + tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens) # reshape the tensors to have (n_audio, n_group) as the first two dimensions audio_features = audio_features[:: self.n_group] - no_caption_probs = no_caption_probs[:: self.n_group] - assert audio_features.shape[0] == len(no_caption_probs) == n_audio + no_speech_probs = no_speech_probs[:: self.n_group] + assert audio_features.shape[0] == len(no_speech_probs) == n_audio tokens = tokens.reshape(n_audio, self.n_group, -1) sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group) @@ -653,7 +653,7 @@ def run(self, mel: Tensor) -> List[DecodingResult]: sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)] avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)] - fields = (texts, languages, tokens, audio_features, avg_logprobs, no_caption_probs) + fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs) if len(set(map(len, fields))) != 1: raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}") @@ -664,11 +664,11 @@ def run(self, mel: Tensor) -> List[DecodingResult]: tokens=tokens, text=text, avg_logprob=avg_logprob, - no_caption_prob=no_caption_prob, + no_speech_prob=no_speech_prob, temperature=self.options.temperature, compression_ratio=compression_ratio(text), ) - for text, language, tokens, features, avg_logprob, no_caption_prob in zip(*fields) + for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields) ] diff --git a/whisper/tokenizer.py b/whisper/tokenizer.py index c115b15ea..5ae691d5f 100644 --- a/whisper/tokenizer.py +++ b/whisper/tokenizer.py @@ -178,8 +178,8 @@ def sot_prev(self) -> int: @property @lru_cache() - def no_captions(self) -> int: - return self._get_single_token_id("<|nocaptions|>") + def no_speech(self) -> int: + return self._get_single_token_id("<|nospeech|>") @property @lru_cache() @@ -283,7 +283,7 @@ def build_tokenizer(name: str = "gpt2"): "<|transcribe|>", "<|startoflm|>", "<|startofprev|>", - "<|nocaptions|>", + "<|nospeech|>", "<|notimestamps|>", ] diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 7000cfab7..5d1ead790 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -23,7 +23,7 @@ def transcribe( temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), compression_ratio_threshold: Optional[float] = 2.4, logprob_threshold: Optional[float] = -1.0, - no_captions_threshold: Optional[float] = 0.6, + no_speech_threshold: Optional[float] = 0.6, **decode_options, ): """ @@ -50,8 +50,8 @@ def transcribe( logprob_threshold: float If the average log probability over sampled tokens is below this value, treat as failed - no_captions_threshold: float - If the no_captions probability is higher than this value AND the average log probability + no_speech_threshold: float + If the no_speech probability is higher than this value AND the average log probability over sampled tokens is below `logprob_threshold`, consider the segment as silent decode_options: dict @@ -148,7 +148,7 @@ def add_segment( "temperature": result.temperature, "avg_logprob": result.avg_logprob, "compression_ratio": result.compression_ratio, - "no_caption_prob": result.no_caption_prob, + "no_speech_prob": result.no_speech_prob, } ) if verbose: @@ -163,11 +163,11 @@ def add_segment( result = decode_with_fallback(segment)[0] tokens = torch.tensor(result.tokens) - if no_captions_threshold is not None: + if no_speech_threshold is not None: # no voice activity check - should_skip = result.no_caption_prob > no_captions_threshold + should_skip = result.no_speech_prob > no_speech_threshold if logprob_threshold is not None and result.avg_logprob > logprob_threshold: - # don't skip if the logprob is high enough, despite the no_captions_prob + # don't skip if the logprob is high enough, despite the no_speech_prob should_skip = False if should_skip: @@ -249,7 +249,7 @@ def cli(): parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below") parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") - parser.add_argument("--no_caption_threshold", type=optional_float, default=0.6, help="if the probability of the <|nocaptions|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") + parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") args = parser.parse_args().__dict__ model_name: str = args.pop("model") @@ -261,12 +261,8 @@ def cli(): warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.") args["language"] = "en" - temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback") - compression_ratio_threshold = args.pop("compression_ratio_threshold") - logprob_threshold = args.pop("logprob_threshold") - no_caption_threshold = args.pop("no_caption_threshold") - temperature = args.pop("temperature") + temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback") if temperature_increment_on_fallback is not None: temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback)) else: @@ -276,15 +272,7 @@ def cli(): model = load_model(model_name, device=device) for audio_path in args.pop("audio"): - result = transcribe( - model, - audio_path, - temperature=temperature, - compression_ratio_threshold=compression_ratio_threshold, - logprob_threshold=logprob_threshold, - no_captions_threshold=no_caption_threshold, - **args, - ) + result = transcribe(model, audio_path, temperature=temperature, **args) audio_basename = os.path.basename(audio_path)