Skip to content

Commit b666045

Browse files
committed
...
1 parent 7ff3277 commit b666045

File tree

1 file changed

+73
-1
lines changed

1 file changed

+73
-1
lines changed

speech2text.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,73 @@
1-
from scipy.io import wavfile
1+
import argparse
2+
import numpy as np
3+
import h5py
4+
import scipy.io.wavfile
5+
import python_speech_features
6+
import torch
7+
import torch.nn as nn
8+
9+
def load_model(model_weights, batch_norm_eps = 0.001, dtype = torch.float32):
10+
def conv_block(kernel_size, num_channels, stride = 1, dilation = 1, repeat = 1, padding = 0):
11+
modules = []
12+
for i in range(repeat):
13+
conv = nn.Conv1d(num_channels[0] if i == 0 else num_channels[1], num_channels[1], kernel_size = kernel_size, stride = stride, dilation = dilation, padding = padding)
14+
modules.append(conv)
15+
modules.append(nn.Hardtanh(0, 20, inplace = True))
16+
return nn.Sequential(*modules)
17+
18+
model = nn.Sequential(
19+
conv_block(kernel_size = 11, num_channels = (64, 256), stride = 2, padding = 5),
20+
conv_block(kernel_size = 11, num_channels = (256, 256), repeat = 3, padding = 5),
21+
conv_block(kernel_size = 13, num_channels = (256, 384), repeat = 3, padding = 6),
22+
conv_block(kernel_size = 17, num_channels = (384, 512), repeat = 3, padding = 8),
23+
conv_block(kernel_size = 21, num_channels = (512, 640), repeat = 3, padding = 10),
24+
conv_block(kernel_size = 25, num_channels = (640, 768), repeat = 3, padding = 12),
25+
conv_block(kernel_size = 29, num_channels = (768, 896), repeat = 1, padding = 28, dilation = 2),
26+
conv_block(kernel_size = 1, num_channels = (896, 1024), repeat = 1),
27+
nn.Conv1d(1024, 29, 1)
28+
)
29+
30+
state_dict = {}
31+
with h5py.File(model_weights) as h:
32+
to_tensor = lambda path: torch.from_numpy(np.asarray(h[path])).to(dtype)
33+
for param_name in model.state_dict():
34+
ij = [int(c) for c in param_name if c.isdigit()]
35+
if len(ij) > 1:
36+
kernel, moving_mean, moving_variance, beta, gamma = [to_tensor(f'ForwardPass/w2l_encoder/conv{1 + ij[0]}{1 + ij[1] // 2}/{suffix}') for suffix in ['kernel', '/bn/moving_mean', '/bn/moving_variance', '/bn/beta', '/bn/gamma']]
37+
factor = gamma * (moving_variance + batch_norm_eps).rsqrt()
38+
kernel *= factor
39+
bias = beta - moving_mean * factor
40+
else:
41+
kernel, bias = [to_tensor(f'ForwardPass/fully_connected_ctc_decoder/fully_connected/{suffix}') for suffix in ['kernel', 'bias']]
42+
kernel.unsqueeze_(0)
43+
state_dict[param_name] = kernel.permute(2, 1, 0) if 'weight' in param_name else bias
44+
model.load_state_dict(state_dict)
45+
return model
46+
47+
if __name__ == '__main__':
48+
parser = argparse.ArgumentParser()
49+
parser.add_argument('-i', '--input_path', default = 'test.wav')
50+
parser.add_argument('--weights', default = 'w2l_plus_large_mp.h5')
51+
args = parser.parse_args()
52+
53+
dtype = torch.float32
54+
55+
sample_rate, signal = scipy.io.wavfile.read(args.input_path)
56+
features = torch.from_numpy(python_speech_features.logfbank(signal=signal,
57+
samplerate=sample_rate,
58+
winlen=20e-3,
59+
winstep=10e-3,
60+
nfilt=64,
61+
nfft=512,
62+
lowfreq=0, highfreq=sample_rate/2,
63+
preemph=0.97)).to(dtype)
64+
65+
batch = features.t().unsqueeze(0)
66+
model = load_model(args.weights, dtype = dtype)
67+
scores = model(batch).squeeze(0)
68+
69+
decoded_greedy = scores.argmax(dim = 0).tolist()
70+
decoded_text = ''.join({0 : ' ', 27 : "'", 28 : '|'}.get(c, chr(c - 1 + ord('a'))) for c in decoded_greedy)
71+
postproc_text = ''.join(c for i, c in enumerate(decoded_text) if i == 0 or c != decoded_text[i - 1]).replace('|', '')
72+
73+
print(postproc_text)

0 commit comments

Comments
 (0)