From 04d8af31066b7666191028328fd67600584b6b7e Mon Sep 17 00:00:00 2001 From: osmr Date: Fri, 4 Jun 2021 20:54:02 +0300 Subject: [PATCH] Work on LibriSpeech DS --- eval_pt.py | 3 +- pytorch/dataset_utils.py | 2 + pytorch/datasets/librispeech_asr_dataset.py | 265 ++++++++++++++++++++ pytorch/metrics/asr_metrics.py | 123 +++++++++ pytorch/pytorchcv/models/jasper.py | 259 ++++++++++++++++++- pytorch/pytorchcv/models/quartznet.py | 5 - pytorch/utils.py | 3 + 7 files changed, 648 insertions(+), 12 deletions(-) create mode 100644 pytorch/datasets/librispeech_asr_dataset.py create mode 100644 pytorch/metrics/asr_metrics.py diff --git a/eval_pt.py b/eval_pt.py index 90fa1f479..c219ee201 100644 --- a/eval_pt.py +++ b/eval_pt.py @@ -133,7 +133,7 @@ def parse_args(): type=str, default="ImageNet1K", help="dataset name. options are ImageNet1K, CUB200_2011, CIFAR10, CIFAR100, SVHN, VOC2012, ADE20K, Cityscapes, " - "COCO") + "COCO, LibriSpeech") parser.add_argument( "--work-dir", type=str, @@ -425,6 +425,7 @@ def main(): "cocoseg": "CocoSeg", "cocohpe": "CocoHpe", "hp": "HPatches", + "ls": "LibriSpeech", } for model_name, model_metainfo in (_model_sha1.items() if version_info[0] >= 3 else _model_sha1.iteritems()): error, checksum, repo_release_tag, caption, paper, ds, img_size, scale, batch, rem = model_metainfo diff --git a/pytorch/dataset_utils.py b/pytorch/dataset_utils.py index 704a29a36..5d193f1b7 100644 --- a/pytorch/dataset_utils.py +++ b/pytorch/dataset_utils.py @@ -18,6 +18,7 @@ from .datasets.coco_hpe2_dataset import CocoHpe2MetaInfo from .datasets.coco_hpe3_dataset import CocoHpe3MetaInfo from .datasets.hpatches_mch_dataset import HPatchesMetaInfo +from .datasets.librispeech_asr_dataset import LibriSpeechMetaInfo from torch.utils.data import DataLoader from torch.utils.data.sampler import WeightedRandomSampler @@ -51,6 +52,7 @@ def get_dataset_metainfo(dataset_name): "CocoHpe2": CocoHpe2MetaInfo, "CocoHpe3": CocoHpe3MetaInfo, "HPatches": HPatchesMetaInfo, + "LibriSpeech": LibriSpeechMetaInfo, } if dataset_name in dataset_metainfo_map.keys(): return dataset_metainfo_map[dataset_name]() diff --git a/pytorch/datasets/librispeech_asr_dataset.py b/pytorch/datasets/librispeech_asr_dataset.py new file mode 100644 index 000000000..e1cede53c --- /dev/null +++ b/pytorch/datasets/librispeech_asr_dataset.py @@ -0,0 +1,265 @@ +""" + LibriSpeech ASR dataset. +""" + +import os +import numpy as np +# import torch +import torch.utils.data as data +import torchvision.transforms as transforms +from .dataset_metainfo import DatasetMetaInfo + + +class LibriSpeech(data.Dataset): + """ + LibriSpeech ASR dataset. + + Parameters: + ---------- + root : str, default '~/.torch/datasets/LibriSpeech' + Path to the folder stored the dataset. + mode : str, default 'test' + 'test-clean'. + transform : function, default None + A function that takes data and transforms it. + target_transform : function, default None + A function that takes label and transforms it. 
+ """ + def __init__(self, + root=os.path.join("~", ".torch", "datasets", "LibriSpeech"), + mode="test", + transform=None, + target_transform=None): + super(LibriSpeech, self).__init__() + self._transform = transform + self._target_transform = target_transform + self.data = [] + + vocabulary = [' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', + 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', "'"] + vocabulary_dict = {c: i for i, c in enumerate(vocabulary)} + + import soundfile + + root_dir_path = os.path.expanduser(root) + assert os.path.exists(root_dir_path) + + if mode in ("val", "test"): + mode = "dev-clean" + + data_dir_path = os.path.join(root_dir_path, mode) + assert os.path.exists(data_dir_path) + + for speaker_id in os.listdir(data_dir_path): + speaker_dir_path = os.path.join(data_dir_path, speaker_id) + for chapter_id in os.listdir(speaker_dir_path): + chapter_dir_path = os.path.join(speaker_dir_path, chapter_id) + transcript_file_path = os.path.join(chapter_dir_path, "{}-{}.trans.txt".format(speaker_id, chapter_id)) + with open(transcript_file_path, "r") as f: + transcripts = dict(x.split(" ", maxsplit=1) for x in f.readlines()) + for flac_file_name in os.listdir(chapter_dir_path): + if flac_file_name.endswith(".flac"): + wav_file_name = flac_file_name.replace(".flac", ".wav") + wav_file_path = os.path.join(chapter_dir_path, wav_file_name) + if not os.path.exists(wav_file_path): + flac_file_path = os.path.join(chapter_dir_path, flac_file_name) + pcm, sample_rate = soundfile.read(flac_file_path) + soundfile.write(wav_file_path, pcm, sample_rate) + text = transcripts[wav_file_name.replace(".wav", "")] + text = text.strip("\n ").lower() + text = np.array([vocabulary_dict[c] for c in text], dtype=np.long) + self.data.append((wav_file_path, text)) + + self.preprocessor = NemoMelSpecExtractor(dither=0.0) + + def __getitem__(self, index): + wav_file_path, label_text = self.data[index] + + audio_data_list = self.read_audio([wav_file_path]) + x_np, x_np_len = self.preprocessor(audio_data_list) + + return (x_np[0], x_np_len[0]), label_text + + def __len__(self): + return len(self.data) + + @staticmethod + def read_audio(audio_file_paths): + """ + Read audio. + + Parameters: + ---------- + audio_file_paths : list of str + Paths to audio files. + + Returns: + ------- + list of np.array + Audio data. + """ + desired_audio_sample_rate = 16000 + + from soundfile import SoundFile + + audio_data_list = [] + for audio_file_path in audio_file_paths: + with SoundFile(audio_file_path, "r") as data: + sample_rate = data.samplerate + audio_data = data.read(dtype="float32") + audio_data = audio_data.transpose() + if desired_audio_sample_rate != sample_rate: + from librosa.core import resample as lr_resample + audio_data = lr_resample(y=audio_data, orig_sr=sample_rate, target_sr=desired_audio_sample_rate) + if audio_data.ndim >= 2: + audio_data = np.mean(audio_data, axis=1) + audio_data_list.append(audio_data) + + return audio_data_list + + +class NemoMelSpecExtractor(object): + """ + Mel-Spectrogram Extractor from NVIDIA NEMO toolkit. + + Parameters: + ---------- + sample_rate : int, default 16000 + Sample rate of the input audio data. + window_size_sec : float, default 0.02 + Size of window for FFT in seconds. + window_stride_sec : float, default 0.01 + Stride of window for FFT in seconds. + n_fft : int, default 512 + Length of FT window. + n_filters : int, default 64 + Number of Mel spectrogram freq bins. 
+ preemph : float, default 0.97 + Amount of pre emphasis to add to audio. + dither : float, default 1.0e-05 + Amount of white-noise dithering. + """ + def __init__(self, + sample_rate=16000, + window_size_sec=0.02, + window_stride_sec=0.01, + n_fft=512, + n_filters=64, + preemph=0.97, + dither=1.0e-05, + **kwargs): + super(NemoMelSpecExtractor, self).__init__(**kwargs) + self.log_zero_guard_value = 2 ** -24 + win_length = int(window_size_sec * sample_rate) + self.hop_length = int(window_stride_sec * sample_rate) + + from scipy import signal as scipy_signal + from librosa import stft as librosa_stft + from librosa.filters import mel as librosa_mel + + window_fn = scipy_signal.hann(win_length, sym=True) + self.stft = lambda x: librosa_stft( + x, + n_fft=n_fft, + hop_length=self.hop_length, + win_length=win_length, + window=window_fn, + center=True) + + self.dither = dither + self.preemph = preemph + + self.pad_align = 16 + + self.filter_bank = librosa_mel( + sample_rate, + n_fft, + n_mels=n_filters, + fmin=0, + fmax=(sample_rate / 2)) + + def __call__(self, xs): + """ + Preprocess audio. + + Parameters: + ---------- + xs : list of np.array + Audio data. + + Returns: + ------- + np.array + Audio data. + np.array + Audio data lengths. + """ + x_eps = 1e-5 + + batch = len(xs) + x_len = np.zeros((batch,), dtype=np.long) + + ys = [] + for i, xi in enumerate(xs): + x_len[i] = np.ceil(float(len(xi)) / self.hop_length).astype(np.long) + + if self.dither > 0: + xi += self.dither * np.random.randn(*xi.shape) + + xi = np.concatenate((xi[:1], xi[1:] - self.preemph * xi[:-1]), axis=0) + + yi = self.stft(xi) + yi = np.abs(yi) + yi = np.square(yi) + yi = np.matmul(self.filter_bank, yi) + yi = np.log(yi + self.log_zero_guard_value) + + assert (yi.shape[1] != 1) + yi_mean = yi.mean(axis=1) + yi_std = yi.std(axis=1) + yi_std += x_eps + yi = (yi - np.expand_dims(yi_mean, axis=-1)) / np.expand_dims(yi_std, axis=-1) + + ys.append(yi) + + channels = ys[0].shape[0] + x_len_max = max([yi.shape[-1] for yi in ys]) + x = np.zeros((batch, channels, x_len_max), dtype=np.float32) + for i, yi in enumerate(ys): + x_len_i = x_len[i] + x[i, :, :x_len_i] = yi[:, :x_len_i] + + pad_rem = x_len_max % self.pad_align + if pad_rem != 0: + x = np.pad(x, ((0, 0), (0, 0), (0, self.pad_align - pad_rem))) + + return x, x_len + + +def ls_test_transform(ds_metainfo): + assert (ds_metainfo is not None) + return transforms.Compose([ + transforms.ToTensor(), + ]) + + +class LibriSpeechMetaInfo(DatasetMetaInfo): + def __init__(self): + super(LibriSpeechMetaInfo, self).__init__() + self.label = "LibriSpeech" + self.short_label = "ls" + self.root_dir_name = "LibriSpeech" + self.dataset_class = LibriSpeech + self.ml_type = "asr" + self.num_classes = 29 + self.vocabulary = [' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', "'"] + self.val_metric_extra_kwargs = [{"vocabulary": self.vocabulary}] + self.val_metric_capts = ["Val.WER"] + self.val_metric_names = ["WER"] + self.test_metric_extra_kwargs = [{}] + self.test_metric_capts = ["Test.WER"] + self.test_metric_names = ["WER"] + self.val_transform = ls_test_transform + self.test_transform = ls_test_transform + self.saver_acc_ind = 0 diff --git a/pytorch/metrics/asr_metrics.py b/pytorch/metrics/asr_metrics.py new file mode 100644 index 000000000..ea5db4d29 --- /dev/null +++ b/pytorch/metrics/asr_metrics.py @@ -0,0 +1,123 @@ +""" +Evaluation Metrics for Automatic Speech Recognition (ASR). 
+""" + +from .metric import EvalMetric + +__all__ = ['WER'] + + +class WER(EvalMetric): + """ + Computes Word Error Rate (WER) for Automatic Speech Recognition (ASR). + + Parameters: + ---------- + vocabulary : list of str + Vocabulary of the dataset. + axis : int, default 1 + The axis that represents classes + name : str, default 'accuracy' + Name of this metric instance for display. + output_names : list of str, or None, default None + Name of predictions that should be used when updating with update_dict. + By default include all predictions. + label_names : list of str, or None, default None + Name of labels that should be used when updating with update_dict. + By default include all labels. + """ + def __init__(self, + vocabulary, + axis=1, + name="wer", + output_names=None, + label_names=None): + super(WER, self).__init__( + name, + axis=axis, + output_names=output_names, + label_names=label_names, + has_global_stats=True) + self.axis = axis + self.vocabulary = vocabulary + self.ctc_decoder = CtcDecoder(vocabulary=vocabulary) + + def update(self, labels, preds): + """ + Updates the internal evaluation result. + + Parameters: + ---------- + labels : torch.Tensor + The labels of the data with class indices as values, one per sample. + preds : torch.Tensor + Prediction values for samples. Each prediction value can either be the class index, + or a vector of likelihoods for all classes. + """ + import editdistance + + labels_code = labels.cpu().numpy() + labels = [] + for label_code in labels_code: + label_text = "".join([self.ctc_decoder.labels_map[c] for c in label_code]) + labels.append(label_text) + + preds = preds[0] + greedy_predictions = preds.transpose(1, 2).log_softmax(dim=-1).argmax(dim=-1, keepdim=False).cpu().numpy() + preds = self.ctc_decoder(greedy_predictions) + + assert (len(labels) == len(preds)) + for pred, label in zip(preds, labels): + pred = pred.split() + label = label.split() + + word_error_count = editdistance.eval(label, pred) + word_count = len(label) + + self.sum_metric += word_error_count + self.global_sum_metric += word_error_count + self.num_inst += word_count + self.global_num_inst += word_count + + +class CtcDecoder(object): + """ + CTC decoder (to decode a sequence of labels to words). + + Parameters: + ---------- + vocabulary : list of str + Vocabulary of the dataset. + """ + def __init__(self, + vocabulary): + super().__init__() + self.blank_id = len(vocabulary) + self.labels_map = dict([(i, vocabulary[i]) for i in range(len(vocabulary))]) + + def __call__(self, + predictions): + """ + Decode a sequence of labels to words. + + Parameters: + ---------- + predictions : np.array of int or list of list of int + Tensor with predicted labels. + + Returns: + ------- + list of str + Words. + """ + hypotheses = [] + for prediction in predictions: + decoded_prediction = [] + previous = self.blank_id + for p in prediction: + if (p != previous or previous == self.blank_id) and p != self.blank_id: + decoded_prediction.append(p) + previous = p + hypothesis = "".join([self.labels_map[c] for c in decoded_prediction]) + hypotheses.append(hypothesis) + return hypotheses diff --git a/pytorch/pytorchcv/models/jasper.py b/pytorch/pytorchcv/models/jasper.py index 7ba31edc1..6b0d2023d 100644 --- a/pytorch/pytorchcv/models/jasper.py +++ b/pytorch/pytorchcv/models/jasper.py @@ -3,9 +3,11 @@ Original paper: 'Jasper: An End-to-End Convolutional Neural Acoustic Model,' https://arxiv.org/abs/1904.03288. 
""" -__all__ = ['Jasper', 'jasper5x3', 'jasper10x4', 'jasper10x5', 'get_jasper', 'CtcDecoder'] +__all__ = ['Jasper', 'jasper5x3', 'jasper10x4', 'jasper10x5', 'get_jasper', 'CtcDecoder', 'NemoMelSpecExtractor', + 'JasperTranscriber'] import os +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -55,6 +57,214 @@ def __call__(self, return hypotheses +class NemoMelSpecExtractor(object): + """ + Mel-Spectrogram Extractor from NVIDIA NEMO toolkit. + + Parameters: + ---------- + sample_rate : int, default 16000 + Sample rate of the input audio data. + window_size_sec : float, default 0.02 + Size of window for FFT in seconds. + window_stride_sec : float, default 0.01 + Stride of window for FFT in seconds. + n_fft : int, default 512 + Length of FT window. + n_filters : int, default 64 + Number of Mel spectrogram freq bins. + preemph : float, default 0.97 + Amount of pre emphasis to add to audio. + dither : float, default 1.0e-05 + Amount of white-noise dithering. + """ + def __init__(self, + sample_rate=16000, + window_size_sec=0.02, + window_stride_sec=0.01, + n_fft=512, + n_filters=64, + preemph=0.97, + dither=1.0e-05, + **kwargs): + super(NemoMelSpecExtractor, self).__init__(**kwargs) + self.log_zero_guard_value = 2 ** -24 + win_length = int(window_size_sec * sample_rate) + self.hop_length = int(window_stride_sec * sample_rate) + + from scipy import signal as scipy_signal + from librosa import stft as librosa_stft + from librosa.filters import mel as librosa_mel + + window_fn = scipy_signal.hann(win_length, sym=True) + self.stft = lambda x: librosa_stft( + x, + n_fft=n_fft, + hop_length=self.hop_length, + win_length=win_length, + window=window_fn, + center=True) + + self.dither = dither + self.preemph = preemph + + self.pad_align = 16 + + self.filter_bank = librosa_mel( + sample_rate, + n_fft, + n_mels=n_filters, + fmin=0, + fmax=(sample_rate / 2)) + + def __call__(self, xs): + """ + Preprocess audio. + + Parameters: + ---------- + xs : list of np.array + Audio data. + + Returns: + ------- + np.array + Audio data. + np.array + Audio data lengths. + """ + x_eps = 1e-5 + + batch = len(xs) + x_len = np.zeros((batch,), dtype=np.long) + + ys = [] + for i, xi in enumerate(xs): + x_len[i] = np.ceil(float(len(xi)) / self.hop_length).astype(np.long) + + if self.dither > 0: + xi += self.dither * np.random.randn(*xi.shape) + + xi = np.concatenate((xi[:1], xi[1:] - self.preemph * xi[:-1]), axis=0) + + yi = self.stft(xi) + yi = np.abs(yi) + yi = np.square(yi) + yi = np.matmul(self.filter_bank, yi) + yi = np.log(yi + self.log_zero_guard_value) + + assert (yi.shape[1] != 1) + yi_mean = yi.mean(axis=1) + yi_std = yi.std(axis=1) + yi_std += x_eps + yi = (yi - np.expand_dims(yi_mean, axis=-1)) / np.expand_dims(yi_std, axis=-1) + + ys.append(yi) + + channels = ys[0].shape[0] + x_len_max = max([yi.shape[-1] for yi in ys]) + x = np.zeros((batch, channels, x_len_max), dtype=np.float32) + for i, yi in enumerate(ys): + x_len_i = x_len[i] + x[i, :, :x_len_i] = yi[:, :x_len_i] + + pad_rem = x_len_max % self.pad_align + if pad_rem != 0: + x = np.pad(x, ((0, 0), (0, 0), (0, self.pad_align - pad_rem))) + + return x, x_len + + +class JasperTranscriber(object): + """ + Jasper/DR/QuartzNet model transcriber. + + Parameters: + ---------- + net : nn.Module + Network. + use_cuda : bool + Whether to use CUDA. + dither : float, default 0.0 + Amount of white-noise dithering. 
+ """ + def __init__(self, + net, + use_cuda, + dither=0.0): + super(JasperTranscriber, self).__init__() + self.use_cuda = use_cuda + self.net = net + self.preprocessor = NemoMelSpecExtractor(dither=dither) + self.ctc_decoder = CtcDecoder(vocabulary=net.vocabulary) + + def __call__(self, audio_file_paths): + """ + Transcribe audio. + + Parameters: + ---------- + audio_file_paths : list of str + Paths to audio files. + + Returns: + ------- + list of str + Transcribed texts. + """ + assert (type(audio_file_paths) in [list, tuple]) + + audio_data_list = self.read_audio(audio_file_paths) + x_np, x_np_len = self.preprocessor(audio_data_list) + + x = torch.from_numpy(x_np) + x_len = torch.from_numpy(x_np_len) + if self.use_cuda: + x = x.cuda() + x_len = x_len.cuda() + + y, y_len = self.net(x, x_len) + + greedy_predictions = y.transpose(1, 2).log_softmax(dim=-1).argmax(dim=-1, keepdim=False).cpu().numpy() + ctc_predictions = self.ctc_decoder(greedy_predictions) + + return ctc_predictions + + @staticmethod + def read_audio(audio_file_paths): + """ + Read audio. + + Parameters: + ---------- + audio_file_paths : list of str + Paths to audio files. + + Returns: + ------- + list of np.array + Audio data. + """ + desired_audio_sample_rate = 16000 + + from soundfile import SoundFile + + audio_data_list = [] + for audio_file_path in audio_file_paths: + with SoundFile(audio_file_path, "r") as data: + sample_rate = data.samplerate + audio_data = data.read(dtype="float32") + audio_data = audio_data.transpose() + if desired_audio_sample_rate != sample_rate: + from librosa.core import resample as lr_resample + audio_data = lr_resample(y=audio_data, orig_sr=sample_rate, target_sr=desired_audio_sample_rate) + if audio_data.ndim >= 2: + audio_data = np.mean(audio_data, axis=1) + audio_data_list.append(audio_data) + + return audio_data_list + + def conv1d1(in_channels, out_channels, stride=1, @@ -589,6 +799,10 @@ class Jasper(nn.Module): Whether to use depthwise block. use_dr : bool Whether to use dense residual scheme. + return_text : bool, default False + Whether to return text instead of logits. + vocabulary : list of str or None, default None + Vocabulary of the dataset. in_channels : int, default 64 Number of input channels (audio features). num_classes : int, default 29 @@ -602,11 +816,15 @@ def __init__(self, repeat, use_dw, use_dr, + return_text=False, + vocabulary=None, in_channels=64, num_classes=29): super(Jasper, self).__init__() self.in_size = None self.num_classes = num_classes + self.vocabulary = vocabulary + self.return_text = return_text self.features = DualPathSequential() init_block_class = DwsConvBlock1d if use_dw else MaskConvBlock1d @@ -648,6 +866,9 @@ def __init__(self, out_channels=num_classes, bias=True) + if self.return_text: + self.ctc_decoder = CtcDecoder(vocabulary=vocabulary) + self._init_params() def _init_params(self): @@ -657,10 +878,38 @@ def _init_params(self): if module.bias is not None: nn.init.constant_(module.bias, 0) - def forward(self, x, x_len): + def forward(self, x, x_len=None): + if x_len is None: + assert (type(x) in (list, tuple)) + x, x_len = x x, x_len = self.features(x, x_len) x = self.output(x) - return x, x_len + + if self.return_text: + greedy_predictions = x.transpose(1, 2).log_softmax(dim=-1).argmax(dim=-1, keepdim=False).cpu().numpy() + return self.ctc_decoder(greedy_predictions) + else: + return x, x_len + + def transcript(self, audio_file_paths, use_cuda=False): + """ + Transcribe audio. 
+ + Parameters: + ---------- + audio_file_paths : list of str + Paths to audio files. + + Returns: + ------- + list of str + Transcribed texts. + use_cuda : bool, default False + Whether to use CUDA. + """ + jt = JasperTranscriber(self, use_cuda=use_cuda) + v = jt(audio_file_paths) + return v def get_jasper(version, @@ -725,10 +974,9 @@ def get_jasper(version, repeat=repeat, use_dw=use_dw, use_dr=use_dr, + vocabulary=vocabulary, **kwargs) - net.vocabulary = vocabulary - if pretrained: if (model_name is None) or (not model_name): raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.") @@ -796,7 +1044,6 @@ def _calc_width(net): def _test(): - import numpy as np import torch pretrained = False diff --git a/pytorch/pytorchcv/models/quartznet.py b/pytorch/pytorchcv/models/quartznet.py index ee92ca279..66797901c 100644 --- a/pytorch/pytorchcv/models/quartznet.py +++ b/pytorch/pytorchcv/models/quartznet.py @@ -8,7 +8,6 @@ 'quartznet15x5_ru'] from .jasper import get_jasper -# from .jasper import CtcDecoder def quartznet5x5_en_ls(num_classes=29, **kwargs): @@ -185,10 +184,6 @@ def _test(): assert (tuple(y.size())[:2] == (batch, net.num_classes)) assert (y.size()[2] in [seq_len_max // 2, seq_len_max // 2 + 1]) - # if net.vocabulary is not None: - # greedy_predictions = y.transpose(1, 2).log_softmax(dim=-1).argmax(dim=-1, keepdim=False).cpu().numpy() - # ctc_predictions = CtcDecoder(vocabulary=net.vocabulary)(greedy_predictions) - if __name__ == "__main__": _test() diff --git a/pytorch/utils.py b/pytorch/utils.py index 37c017dd5..54f0855b9 100644 --- a/pytorch/utils.py +++ b/pytorch/utils.py @@ -12,6 +12,7 @@ from .metrics.seg_metrics import PixelAccuracyMetric, MeanIoUMetric from .metrics.det_metrics import CocoDetMApMetric from .metrics.hpe_metrics import CocoHpeOksApMetric +from .metrics.asr_metrics import WER def prepare_pt_context(num_gpus, @@ -247,6 +248,8 @@ def get_metric(metric_name, metric_extra_kwargs): return CocoDetMApMetric(**metric_extra_kwargs) elif metric_name == "CocoHpeOksApMetric": return CocoHpeOksApMetric(**metric_extra_kwargs) + elif metric_name == "WER": + return WER(**metric_extra_kwargs) else: raise Exception("Wrong metric name: {}".format(metric_name))
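
Usage sketch for the new feature extractor (illustrative, not part of the patch): NemoMelSpecExtractor turns a list of 16 kHz waveforms into a padded batch of log-mel features. The dummy waveform and the repository-root import path below are assumptions.

import numpy as np

from pytorch.pytorchcv.models.jasper import NemoMelSpecExtractor

# dither=0.0 keeps the extraction deterministic (the LibriSpeech dataset class uses the same setting)
extractor = NemoMelSpecExtractor(dither=0.0)

# one second of dummy 16 kHz audio; a real waveform would come from soundfile.read
waveform = np.random.randn(16000).astype(np.float32)

feats, feat_lens = extractor([waveform])
print(feats.shape)   # (1, 64, T) with T zero-padded up to a multiple of 16
print(feat_lens)     # number of valid frames per utterance, here [100]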
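
Usage sketch for the new transcription helper (illustrative): Jasper.transcript wires NemoMelSpecExtractor, the acoustic model, and greedy CTC decoding together through JasperTranscriber. It assumes pretrained weights are published for quartznet5x5_en_ls and that sample.wav is a mono speech recording (read_audio resamples to 16 kHz if needed).

from pytorch.pytorchcv.models.quartznet import quartznet5x5_en_ls

# get_jasper attaches the vocabulary to the network, so the CTC decoder is ready to use
net = quartznet5x5_en_ls(pretrained=True)
net.eval()

texts = net.transcript(["sample.wav"], use_cuda=False)
print(texts)  # e.g. ["the quick brown fox ..."]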
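
Usage sketch for computing WER outside of eval_pt.py (illustrative; the integration details below are assumptions). It assumes the LibriSpeech dev-clean split is unpacked under ~/.torch/datasets/LibriSpeech/dev-clean and uses batch_size=1 so that variable-length transcripts collate without padding.

import torch
from torch.utils.data import DataLoader

from pytorch.datasets.librispeech_asr_dataset import LibriSpeech, LibriSpeechMetaInfo
from pytorch.metrics.asr_metrics import WER
from pytorch.pytorchcv.models.quartznet import quartznet5x5_en_ls

ds = LibriSpeech(mode="dev-clean")  # converts .flac files to .wav on first use
loader = DataLoader(ds, batch_size=1, shuffle=False)

net = quartznet5x5_en_ls(pretrained=True).eval()
metric = WER(vocabulary=LibriSpeechMetaInfo().vocabulary)

with torch.no_grad():
    for (x, x_len), label in loader:
        y, y_len = net(x, x_len)
        metric.update(label, (y, y_len))

print(metric.get())  # ('wer', <word error rate over the split>)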