|
1 | | -from scipy.io import wavfile |
| 1 | +import argparse |
| 2 | +import numpy as np |
| 3 | +import h5py |
| 4 | +import scipy.io.wavfile |
| 5 | +import python_speech_features |
| 6 | +import torch |
| 7 | +import torch.nn as nn |
| 8 | + |
| 9 | +def load_model(model_weights, batch_norm_eps = 0.001, dtype = torch.float32): |
| 10 | + def conv_block(kernel_size, num_channels, stride = 1, dilation = 1, repeat = 1, padding = 0): |
| 11 | + modules = [] |
| 12 | + for i in range(repeat): |
| 13 | + conv = nn.Conv1d(num_channels[0] if i == 0 else num_channels[1], num_channels[1], kernel_size = kernel_size, stride = stride, dilation = dilation, padding = padding) |
| 14 | + modules.append(conv) |
| 15 | + modules.append(nn.Hardtanh(0, 20, inplace = True)) |
| 16 | + return nn.Sequential(*modules) |
| 17 | + |
| 18 | + model = nn.Sequential( |
| 19 | + conv_block(kernel_size = 11, num_channels = (64, 256), stride = 2, padding = 5), |
| 20 | + conv_block(kernel_size = 11, num_channels = (256, 256), repeat = 3, padding = 5), |
| 21 | + conv_block(kernel_size = 13, num_channels = (256, 384), repeat = 3, padding = 6), |
| 22 | + conv_block(kernel_size = 17, num_channels = (384, 512), repeat = 3, padding = 8), |
| 23 | + conv_block(kernel_size = 21, num_channels = (512, 640), repeat = 3, padding = 10), |
| 24 | + conv_block(kernel_size = 25, num_channels = (640, 768), repeat = 3, padding = 12), |
| 25 | + conv_block(kernel_size = 29, num_channels = (768, 896), repeat = 1, padding = 28, dilation = 2), |
| 26 | + conv_block(kernel_size = 1, num_channels = (896, 1024), repeat = 1), |
| 27 | + nn.Conv1d(1024, 29, 1) |
| 28 | + ) |
| 29 | + |
| 30 | + state_dict = {} |
| 31 | + with h5py.File(model_weights) as h: |
| 32 | + to_tensor = lambda path: torch.from_numpy(np.asarray(h[path])).to(dtype) |
| 33 | + for param_name in model.state_dict(): |
| 34 | + ij = [int(c) for c in param_name if c.isdigit()] |
| 35 | + if len(ij) > 1: |
| 36 | + kernel, moving_mean, moving_variance, beta, gamma = [to_tensor(f'ForwardPass/w2l_encoder/conv{1 + ij[0]}{1 + ij[1] // 2}/{suffix}') for suffix in ['kernel', '/bn/moving_mean', '/bn/moving_variance', '/bn/beta', '/bn/gamma']] |
| 37 | + factor = gamma * (moving_variance + batch_norm_eps).rsqrt() |
| 38 | + kernel *= factor |
| 39 | + bias = beta - moving_mean * factor |
| 40 | + else: |
| 41 | + kernel, bias = [to_tensor(f'ForwardPass/fully_connected_ctc_decoder/fully_connected/{suffix}') for suffix in ['kernel', 'bias']] |
| 42 | + kernel.unsqueeze_(0) |
| 43 | + state_dict[param_name] = kernel.permute(2, 1, 0) if 'weight' in param_name else bias |
| 44 | + model.load_state_dict(state_dict) |
| 45 | + return model |
| 46 | + |
| 47 | +if __name__ == '__main__': |
| 48 | + parser = argparse.ArgumentParser() |
| 49 | + parser.add_argument('-i', '--input_path', default = 'test.wav') |
| 50 | + parser.add_argument('--weights', default = 'w2l_plus_large_mp.h5') |
| 51 | + args = parser.parse_args() |
| 52 | + |
| 53 | + dtype = torch.float32 |
| 54 | + |
| 55 | + sample_rate, signal = scipy.io.wavfile.read(args.input_path) |
| 56 | + features = torch.from_numpy(python_speech_features.logfbank(signal=signal, |
| 57 | + samplerate=sample_rate, |
| 58 | + winlen=20e-3, |
| 59 | + winstep=10e-3, |
| 60 | + nfilt=64, |
| 61 | + nfft=512, |
| 62 | + lowfreq=0, highfreq=sample_rate/2, |
| 63 | + preemph=0.97)).to(dtype) |
| 64 | + |
| 65 | + batch = features.t().unsqueeze(0) |
| 66 | + model = load_model(args.weights, dtype = dtype) |
| 67 | + scores = model(batch).squeeze(0) |
| 68 | + |
| 69 | + decoded_greedy = scores.argmax(dim = 0).tolist() |
| 70 | + decoded_text = ''.join({0 : ' ', 27 : "'", 28 : '|'}.get(c, chr(c - 1 + ord('a'))) for c in decoded_greedy) |
| 71 | + postproc_text = ''.join(c for i, c in enumerate(decoded_text) if i == 0 or c != decoded_text[i - 1]).replace('|', '') |
| 72 | + |
| 73 | + print(postproc_text) |
0 commit comments