diff --git a/utils_vad.py b/utils_vad.py
new file mode 100644
index 0000000..aaa8d15
--- /dev/null
+++ b/utils_vad.py
@@ -0,0 +1,566 @@
+# MIT License
+
+# Copyright (c) 2020-present Silero Team
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import torch
+import torchaudio
+from typing import Callable, List
+import torch.nn.functional as F
+import warnings
+
+languages = ['ru', 'en', 'de', 'es']
+
+
+class OnnxWrapper():
+
+    def __init__(self, path, force_onnx_cpu=False):
+        import numpy as np
+        global np
+        import onnxruntime
+
+        opts = onnxruntime.SessionOptions()
+        opts.inter_op_num_threads = 1
+        opts.intra_op_num_threads = 1
+
+        if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
+            self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
+        else:
+            self.session = onnxruntime.InferenceSession(path, sess_options=opts)
+
+        self.reset_states()
+        self.sample_rates = [8000, 16000]
+
+    def _validate_input(self, x, sr: int):
+        if x.dim() == 1:
+            x = x.unsqueeze(0)
+        if x.dim() > 2:
+            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
+
+        if sr != 16000 and (sr % 16000 == 0):
+            step = sr // 16000
+            x = x[:, ::step]
+            sr = 16000
+
+        if sr not in self.sample_rates:
+            raise ValueError(f"Supported sampling rates: {self.sample_rates} (or a multiple of 16000)")
+
+        if sr / x.shape[1] > 31.25:
+            raise ValueError("Input audio chunk is too short")
+
+        return x, sr
+
+    def reset_states(self, batch_size=1):
+        self._h = np.zeros((2, batch_size, 64)).astype('float32')
+        self._c = np.zeros((2, batch_size, 64)).astype('float32')
+        self._last_sr = 0
+        self._last_batch_size = 0
+
+    def __call__(self, x, sr: int):
+
+        x, sr = self._validate_input(x, sr)
+        batch_size = x.shape[0]
+
+        if not self._last_batch_size:
+            self.reset_states(batch_size)
+        if (self._last_sr) and (self._last_sr != sr):
+            self.reset_states(batch_size)
+        if (self._last_batch_size) and (self._last_batch_size != batch_size):
+            self.reset_states(batch_size)
+
+        if sr in [8000, 16000]:
+            ort_inputs = {'input': x.numpy(), 'h': self._h, 'c': self._c, 'sr': np.array(sr, dtype='int64')}
+            ort_outs = self.session.run(None, ort_inputs)
+            out, self._h, self._c = ort_outs
+        else:
+            raise ValueError(f"Unsupported sampling rate: {sr}")
+
+        self._last_sr = sr
+        self._last_batch_size = batch_size
+
+        out = torch.tensor(out)
+        return out
+
+    def audio_forward(self, x, sr: int, num_samples: int = 512):
+        outs = []
+        x, sr = self._validate_input(x, sr)
+
+        if x.shape[1] % num_samples:
+            pad_num = num_samples - (x.shape[1] % num_samples)
+            x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)
+
+        self.reset_states(x.shape[0])
+        for i in range(0, x.shape[1], num_samples):
+            wavs_batch = x[:, i:i+num_samples]
+            out_chunk = self.__call__(wavs_batch, sr)
+            outs.append(out_chunk)
+
+        stacked = torch.cat(outs, dim=1)
+        return stacked.cpu()
+
+
+class Validator():
+    def __init__(self, url, force_onnx_cpu):
+        self.onnx = True if url.endswith('.onnx') else False
+        torch.hub.download_url_to_file(url, 'inf.model')
+        if self.onnx:
+            import onnxruntime
+            if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
+                self.model = onnxruntime.InferenceSession('inf.model', providers=['CPUExecutionProvider'])
+            else:
+                self.model = onnxruntime.InferenceSession('inf.model')
+        else:
+            self.model = init_jit_model(model_path='inf.model')
+
+    def __call__(self, inputs: torch.Tensor):
+        with torch.no_grad():
+            if self.onnx:
+                ort_inputs = {'input': inputs.cpu().numpy()}
+                outs = self.model.run(None, ort_inputs)
+                outs = [torch.Tensor(x) for x in outs]
+            else:
+                outs = self.model(inputs)
+
+        return outs
+
+
+def read_audio(path: str,
+               sampling_rate: int = 16000):
+
+    sox_backends = set(['sox', 'sox_io'])
+    audio_backends = torchaudio.list_audio_backends()
+
+    if len(sox_backends.intersection(audio_backends)) > 0:
+        effects = [
+            ['channels', '1'],
+            ['rate', str(sampling_rate)]
+        ]
+
+        wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
+    else:
+        wav, sr = torchaudio.load(path)
+
+        if wav.size(0) > 1:
+            wav = wav.mean(dim=0, keepdim=True)
+
+        if sr != sampling_rate:
+            transform = torchaudio.transforms.Resample(orig_freq=sr,
+                                                       new_freq=sampling_rate)
+            wav = transform(wav)
+            sr = sampling_rate
+
+    assert sr == sampling_rate
+    return wav.squeeze(0)
+
+
+def save_audio(path: str,
+               tensor: torch.Tensor,
+               sampling_rate: int = 16000):
+    torchaudio.save(path, tensor.unsqueeze(0), sampling_rate, bits_per_sample=16)
+
+
+def init_jit_model(model_path: str,
+                   device=torch.device('cpu')):
+    torch.set_grad_enabled(False)
+    model = torch.jit.load(model_path, map_location=device)
+    model.eval()
+    return model
+
+
+def make_visualization(probs, step):
+    import pandas as pd
+    pd.DataFrame({'probs': probs},
+                 index=[x * step for x in range(len(probs))]).plot(figsize=(16, 8),
+                 kind='area', ylim=[0, 1.05], xlim=[0, len(probs) * step],
+                 xlabel='seconds',
+                 ylabel='speech probability',
+                 colormap='tab20')
+
+
+def get_speech_timestamps(audio: torch.Tensor,
+                          model,
+                          threshold: float = 0.5,
+                          sampling_rate: int = 16000,
+                          min_speech_duration_ms: int = 250,
+                          max_speech_duration_s: float = float('inf'),
+                          min_silence_duration_ms: int = 100,
+                          window_size_samples: int = 512,
+                          speech_pad_ms: int = 30,
+                          return_seconds: bool = False,
+                          visualize_probs: bool = False,
+                          progress_tracking_callback: Callable[[float], None] = None):
+
+    """
+    This method is used for splitting long audio into speech chunks using Silero VAD
+
+    Parameters
+    ----------
+    audio: torch.Tensor, one dimensional
+        One dimensional float torch.Tensor; other types are cast to torch.Tensor if possible
+
+    model: preloaded .jit silero VAD model
+
+    threshold: float (default - 0.5)
+        Speech threshold. Silero VAD outputs speech probabilities for each audio chunk; probabilities ABOVE this value are considered as SPEECH.
+        It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
+
+    sampling_rate: int (default - 16000)
+        Currently silero VAD models support 8000 and 16000 sample rates
+
+    min_speech_duration_ms: int (default - 250 milliseconds)
+        Final speech chunks shorter than min_speech_duration_ms are thrown out
+
+    max_speech_duration_s: float (default - inf)
+        Maximum duration of speech chunks in seconds
+        Chunks longer than max_speech_duration_s will be split at the timestamp of the last silence that lasts more than 100 ms (if any), to prevent aggressive cutting.
+        Otherwise, they will be split aggressively just before max_speech_duration_s.
+
+    min_silence_duration_ms: int (default - 100 milliseconds)
+        At the end of each speech chunk, wait for min_silence_duration_ms before separating it
+
+    window_size_samples: int (default - 512 samples)
+        Audio chunks of window_size_samples size are fed to the silero VAD model.
+        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for the 16000 sample rate and 256, 512, 768 samples for the 8000 sample rate.
+        Values other than these may affect model performance!!
+
+    speech_pad_ms: int (default - 30 milliseconds)
+        Final speech chunks are padded by speech_pad_ms on each side
+
+    return_seconds: bool (default - False)
+        whether to return timestamps in seconds (default - samples)
+
+    visualize_probs: bool (default - False)
+        whether to draw the speech probability plot or not
+
+    progress_tracking_callback: Callable[[float], None] (default - None)
+        callback function taking progress in percent as an argument
+
+    Returns
+    ----------
+    speeches: list of dicts
+        list containing the beginnings and ends of speech chunks (samples or seconds based on return_seconds)
+    """
+
+    if not torch.is_tensor(audio):
+        try:
+            audio = torch.Tensor(audio)
+        except:
+            raise TypeError("Audio cannot be cast to tensor. Cast it manually")
+
+    if len(audio.shape) > 1:
+        for i in range(len(audio.shape)):  # trying to squeeze empty dimensions
+            audio = audio.squeeze(0)
+        if len(audio.shape) > 1:
+            raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?")
+
+    if sampling_rate > 16000 and (sampling_rate % 16000 == 0):
+        step = sampling_rate // 16000
+        sampling_rate = 16000
+        audio = audio[::step]
+        warnings.warn('Sampling rate is a multiple of 16000, casting to 16000 manually!')
+    else:
+        step = 1
+
+    if sampling_rate == 8000 and window_size_samples > 768:
+        warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 768 for the 8000 sample rate!')
+    if window_size_samples not in [256, 512, 768, 1024, 1536]:
+        warnings.warn('Unusual window_size_samples! Supported window_size_samples:\n - [512, 1024, 1536] for 16000 sampling_rate\n - [256, 512, 768] for 8000 sampling_rate')
+
+    model.reset_states()
+    min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
+    speech_pad_samples = sampling_rate * speech_pad_ms / 1000
+    max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
+    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
+    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
+
+    audio_length_samples = len(audio)
+
+    speech_probs = []
+    for current_start_sample in range(0, audio_length_samples, window_size_samples):
+        chunk = audio[current_start_sample: current_start_sample + window_size_samples]
+        if len(chunk) < window_size_samples:
+            chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
+        speech_prob = model(chunk, sampling_rate).item()
+        speech_probs.append(speech_prob)
+        # calculate progress and send it to the callback function
+        progress = current_start_sample + window_size_samples
+        if progress > audio_length_samples:
+            progress = audio_length_samples
+        progress_percent = (progress / audio_length_samples) * 100
+        if progress_tracking_callback:
+            progress_tracking_callback(progress_percent)
+
+    triggered = False
+    speeches = []
+    current_speech = {}
+    neg_threshold = threshold - 0.15
+    temp_end = 0  # to save potential segment end (and tolerate some silence)
+    prev_end = next_start = 0  # to save potential segment limits in case of maximum segment size reached
+
+    for i, speech_prob in enumerate(speech_probs):
+        if (speech_prob >= threshold) and temp_end:
+            temp_end = 0
+            if next_start < prev_end:
+                next_start = window_size_samples * i
+
+        if (speech_prob >= threshold) and not triggered:
+            triggered = True
+            current_speech['start'] = window_size_samples * i
+            continue
+
+        if triggered and (window_size_samples * i) - current_speech['start'] > max_speech_samples:
+            if prev_end:
+                current_speech['end'] = prev_end
+                speeches.append(current_speech)
+                current_speech = {}
+                if next_start < prev_end:  # previously reached silence (< neg_thres) and is still not speech (< thres)
+                    triggered = False
+                else:
+                    current_speech['start'] = next_start
+                prev_end = next_start = temp_end = 0
+            else:
+                current_speech['end'] = window_size_samples * i
+                speeches.append(current_speech)
+                current_speech = {}
+                prev_end = next_start = temp_end = 0
+                triggered = False
+                continue
+
+        if (speech_prob < neg_threshold) and triggered:
+            if not temp_end:
+                temp_end = window_size_samples * i
+            if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:  # condition to avoid cutting in very short silence
+                prev_end = temp_end
+            if (window_size_samples * i) - temp_end < min_silence_samples:
+                continue
+            else:
+                current_speech['end'] = temp_end
+                if (current_speech['end'] - current_speech['start']) > min_speech_samples:
+                    speeches.append(current_speech)
+                current_speech = {}
+                prev_end = next_start = temp_end = 0
+                triggered = False
+                continue
+
+    if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
+        current_speech['end'] = audio_length_samples
+        speeches.append(current_speech)
+
+    for i, speech in enumerate(speeches):
+        if i == 0:
+            speech['start'] = int(max(0, speech['start'] - speech_pad_samples))
+        if i != len(speeches) - 1:
+            silence_duration = speeches[i+1]['start'] - speech['end']
+            if silence_duration < 2 * speech_pad_samples:
+                speech['end'] += int(silence_duration // 2)
+                speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - silence_duration // 2))
+            else:
+                speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
+                speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - speech_pad_samples))
+        else:
+            speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
+
+    if return_seconds:
+        for speech_dict in speeches:
+            speech_dict['start'] = round(speech_dict['start'] / sampling_rate, 1)
+            speech_dict['end'] = round(speech_dict['end'] / sampling_rate, 1)
+    elif step > 1:
+        for speech_dict in speeches:
+            speech_dict['start'] *= step
+            speech_dict['end'] *= step
+
+    if visualize_probs:
+        make_visualization(speech_probs, window_size_samples / sampling_rate)
+
+    return speeches
+
+
+def get_number_ts(wav: torch.Tensor,
+                  model,
+                  model_stride=8,
+                  hop_length=160,
+                  sample_rate=16000):
+    wav = torch.unsqueeze(wav, dim=0)
+    perframe_logits = model(wav)[0]
+    perframe_preds = torch.argmax(torch.softmax(perframe_logits, dim=1), dim=1).squeeze()  # (1, num_frames_strided)
+    extended_preds = []
+    for i in perframe_preds:
+        extended_preds.extend([i.item()] * model_stride)
+    # len(extended_preds) is *num_frames_real*; for each frame of audio we know if it has a number in it.
+    triggered = False
+    timings = []
+    cur_timing = {}
+    for i, pred in enumerate(extended_preds):
+        if pred == 1:
+            if not triggered:
+                cur_timing['start'] = int((i * hop_length) / (sample_rate / 1000))
+                triggered = True
+        elif pred == 0:
+            if triggered:
+                cur_timing['end'] = int((i * hop_length) / (sample_rate / 1000))
+                timings.append(cur_timing)
+                cur_timing = {}
+                triggered = False
+    if cur_timing:
+        cur_timing['end'] = int(wav.shape[-1] / (sample_rate / 1000))  # total length in ms; wav is (1, num_samples) after unsqueeze
+        timings.append(cur_timing)
+    return timings
+
+
+def get_language(wav: torch.Tensor,
+                 model):
+    wav = torch.unsqueeze(wav, dim=0)
+    lang_logits = model(wav)[2]
+    lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item()  # from 0 to len(languages) - 1
+    assert lang_pred < len(languages)
+    return languages[lang_pred]
+
+
+def get_language_and_group(wav: torch.Tensor,
+                           model,
+                           lang_dict: dict,
+                           lang_group_dict: dict,
+                           top_n=1):
+    wav = torch.unsqueeze(wav, dim=0)
+    lang_logits, lang_group_logits = model(wav)
+
+    softm = torch.softmax(lang_logits, dim=1).squeeze()
+    softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()
+
+    srtd = torch.argsort(softm, descending=True)
+    srtd_group = torch.argsort(softm_group, descending=True)
+
+    outs = []
+    outs_group = []
+    for i in range(top_n):
+        prob = round(softm[srtd[i]].item(), 2)
+        prob_group = round(softm_group[srtd_group[i]].item(), 2)
+        outs.append((lang_dict[str(srtd[i].item())], prob))
+        outs_group.append((lang_group_dict[str(srtd_group[i].item())], prob_group))
+
+    return outs, outs_group
+
+
+class VADIterator:
+    def __init__(self,
+                 model,
+                 threshold: float = 0.5,
+                 sampling_rate: int = 16000,
+                 min_silence_duration_ms: int = 100,
+                 speech_pad_ms: int = 30
+                 ):
+
+        """
+        Class for stream imitation (feeding audio to the VAD chunk by chunk)
+
+        Parameters
+        ----------
+        model: preloaded .jit silero VAD model
+
+        threshold: float (default - 0.5)
+            Speech threshold. Silero VAD outputs speech probabilities for each audio chunk; probabilities ABOVE this value are considered as SPEECH.
+            It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
+
+        sampling_rate: int (default - 16000)
+            Currently silero VAD models support 8000 and 16000 sample rates
+
+        min_silence_duration_ms: int (default - 100 milliseconds)
+            At the end of each speech chunk, wait for min_silence_duration_ms before separating it
+
+        speech_pad_ms: int (default - 30 milliseconds)
+            Final speech chunks are padded by speech_pad_ms on each side
+        """
+
+        self.model = model
+        self.threshold = threshold
+        self.sampling_rate = sampling_rate
+
+        if sampling_rate not in [8000, 16000]:
+            raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
+
+        self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
+        self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
+        self.reset_states()
+
+    def reset_states(self):
+
+        self.model.reset_states()
+        self.triggered = False
+        self.temp_end = 0
+        self.current_sample = 0
+
+    def __call__(self, x, return_seconds=False):
+        """
+        x: torch.Tensor
+            audio chunk (see examples in repo)
+
+        return_seconds: bool (default - False)
+            whether to return timestamps in seconds (default - samples)
+        """
+
+        if not torch.is_tensor(x):
+            try:
+                x = torch.Tensor(x)
+            except:
+                raise TypeError("Audio cannot be cast to tensor. Cast it manually")
+
+        window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
+        self.current_sample += window_size_samples
+
+        speech_prob = self.model(x, self.sampling_rate).item()
+
+        if (speech_prob >= self.threshold) and self.temp_end:
+            self.temp_end = 0
+
+        if (speech_prob >= self.threshold) and not self.triggered:
+            self.triggered = True
+            speech_start = self.current_sample - self.speech_pad_samples - window_size_samples
+            return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
+
+        if (speech_prob < self.threshold - 0.15) and self.triggered:
+            if not self.temp_end:
+                self.temp_end = self.current_sample
+            if self.current_sample - self.temp_end < self.min_silence_samples:
+                return None
+            else:
+                speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
+                self.temp_end = 0
+                self.triggered = False
+                return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
+
+        return None
+
+
+def collect_chunks(tss: List[dict],
+                   wav: torch.Tensor):
+    chunks = []
+    for i in tss:
+        chunks.append(wav[i['start']: i['end']])
+    return torch.cat(chunks)
+
+
+def drop_chunks(tss: List[dict],
+                wav: torch.Tensor):
+    chunks = []
+    cur_start = 0
+    for i in tss:
+        chunks.append(wav[cur_start: i['start']])
+        cur_start = i['end']
+    return torch.cat(chunks)
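
# Usage sketch (offline segmentation). A minimal example of how the helpers above are meant
# to be combined; it assumes this file is importable as utils_vad and that a Silero VAD JIT
# model is available locally. 'silero_vad.jit' and 'example.wav' are illustrative paths, not
# files added by this patch.
from utils_vad import (init_jit_model, read_audio, get_speech_timestamps,
                       collect_chunks, save_audio)

model = init_jit_model('silero_vad.jit')              # TorchScript VAD model
wav = read_audio('example.wav', sampling_rate=16000)  # 1-D float tensor at 16 kHz
speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=16000)
# speech_timestamps is a list of {'start': ..., 'end': ...} dicts, in samples by default;
# pass return_seconds=True for seconds, or progress_tracking_callback=print to watch progress
save_audio('only_speech.wav', collect_chunks(speech_timestamps, wav), sampling_rate=16000)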
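
# Usage sketch (streaming). Shows how VADIterator is driven chunk by chunk; same assumptions
# and illustrative paths as the sketch above. 512 samples per window is one of the sizes the
# 16 kHz models were trained with.
from utils_vad import init_jit_model, read_audio, VADIterator

model = init_jit_model('silero_vad.jit')
vad_iterator = VADIterator(model, sampling_rate=16000)
wav = read_audio('example.wav', sampling_rate=16000)

window_size_samples = 512
for i in range(0, len(wav), window_size_samples):
    chunk = wav[i: i + window_size_samples]
    if len(chunk) < window_size_samples:
        break  # drop the trailing partial window (audio_forward pads it instead)
    speech_dict = vad_iterator(chunk, return_seconds=True)
    if speech_dict:  # either {'start': ...} or {'end': ...}, in seconds here
        print(speech_dict)
vad_iterator.reset_states()  # reset model and iterator state before the next audio file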