forked from mozilla/DeepSpeech
-
Notifications
You must be signed in to change notification settings - Fork 0
/
client.py
120 lines (93 loc) · 4.46 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import argparse
import numpy as np
import shlex
import subprocess
import sys
import wave
from deepspeech import Model, printVersions
from timeit import default_timer as timer
# `quote` lives in `shlex` on Python 3; Python 2 only ships it in `pipes`.
try:
    from shlex import quote
except ImportError:
    from pipes import quote
# Decoder settings: these constants control the beam search decoder.
# Beam width used in the CTC decoder when building candidate transcriptions
# (passed to the Model constructor in main()).
BEAM_WIDTH = 500
# The alpha hyperparameter of the CTC decoder: language model weight
# (passed to enableDecoderWithLM() in main()).
LM_ALPHA = 0.75
# The beta hyperparameter of the CTC decoder: word insertion bonus
# (passed to enableDecoderWithLM() in main()).
LM_BETA = 1.85
# Graph-shape settings: these constants are tied to the shape of the graph used
# (changing them changes the geometry of the first layer), so make sure you use
# the same constants that were used during training.
# Number of MFCC features to use
N_FEATURES = 26
# Size of the context window used for producing timesteps in the input vector
N_CONTEXT = 9
def convert_samplerate(audio_path, desired_sample_rate=16000):
    """Resample an audio file with SoX and return the raw PCM samples.

    SoX is asked to write headerless, single-channel, 16-bit signed
    little-endian PCM at ``desired_sample_rate`` to its standard output; that
    byte stream is then viewed as a NumPy ``int16`` array.

    Args:
        audio_path: Path of the audio file handed to ``sox``.
        desired_sample_rate: Target sample rate in Hz (defaults to 16000).

    Returns:
        A ``(sample_rate, samples)`` tuple where ``sample_rate`` is
        ``desired_sample_rate`` and ``samples`` is a one-dimensional
        ``numpy.int16`` array over SoX's output.

    Raises:
        RuntimeError: If ``sox`` exits with a non-zero status.
        OSError: If the ``sox`` executable could not be started.
    """
    # Build the argument vector directly: no shell is involved, so the path
    # needs no quoting and cannot be mis-split on whitespace.
    sox_cmd = [
        'sox', audio_path,
        '--type', 'raw',
        '--bits', '16',
        '--channels', '1',
        '--rate', str(desired_sample_rate),
        '--encoding', 'signed-integer',
        '--endian', 'little',
        '--compression', '0.0',
        '--no-dither',
        '-',
    ]
    try:
        output = subprocess.check_output(sox_cmd, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        # `stderr` is bytes on Python 3 and absent on Python 2's exception
        # type; normalise it so the message is readable text either way.
        stderr = getattr(e, 'stderr', None) or b''
        if not isinstance(stderr, str):
            stderr = stderr.decode('utf-8', 'replace')
        raise RuntimeError('SoX returned non-zero status: {}'.format(stderr))
    except OSError as e:
        raise OSError(e.errno, 'SoX not found, use {}Hz files or install it: {}'.format(desired_sample_rate, e.strerror))
    return desired_sample_rate, np.frombuffer(output, np.int16)
def metadata_to_string(metadata):
    """Join the ``character`` of every entry in ``metadata.items``, in order."""
    characters = [entry.character for entry in metadata.items]
    return ''.join(characters)
class VersionAction(argparse.Action):
    """argparse action that calls ``printVersions()`` and exits with status 0."""

    def __init__(self, *args, **kwargs):
        # nargs=0 makes the option a bare flag that consumes no value.
        super(VersionAction, self).__init__(nargs=0, *args, **kwargs)

    def __call__(self, *args, **kwargs):
        printVersions()
        # sys.exit rather than the `exit` helper, which is injected by the
        # `site` module and is not guaranteed to exist (e.g. `python -S`).
        sys.exit(0)
def main():
    """Command-line entry point: transcribe one WAV file with DeepSpeech.

    Parses the command line, loads the model (and the language model when
    both ``--lm`` and ``--trie`` are given), reads the audio file, resamples
    it through ``convert_samplerate`` if its frame rate is not 16000, runs
    inference and prints the transcription to standard output. Progress and
    timing messages go to standard error.
    """
    parser = argparse.ArgumentParser(description='Running DeepSpeech inference.')
    parser.add_argument('--model', required=True,
                        help='Path to the model (protocol buffer binary file)')
    parser.add_argument('--alphabet', required=True,
                        help='Path to the configuration file specifying the alphabet used by the network')
    parser.add_argument('--lm', nargs='?',
                        help='Path to the language model binary file')
    parser.add_argument('--trie', nargs='?',
                        help='Path to the language model trie file created with native_client/generate_trie')
    parser.add_argument('--audio', required=True,
                        help='Path to the audio file to run (WAV format)')
    parser.add_argument('--version', action=VersionAction,
                        help='Print version and exits')
    parser.add_argument('--extended', required=False, action='store_true',
                        help='Output string from extended metadata')
    args = parser.parse_args()

    print('Loading model from file {}'.format(args.model), file=sys.stderr)
    model_load_start = timer()
    ds = Model(args.model, N_FEATURES, N_CONTEXT, args.alphabet, BEAM_WIDTH)
    model_load_end = timer() - model_load_start
    # Fixed three decimals, matching the final inference report below.
    print('Loaded model in {:.3f}s.'.format(model_load_end), file=sys.stderr)

    # The language model is only enabled when both of its files are supplied.
    if args.lm and args.trie:
        print('Loading language model from files {} {}'.format(args.lm, args.trie), file=sys.stderr)
        lm_load_start = timer()
        ds.enableDecoderWithLM(args.alphabet, args.lm, args.trie, LM_ALPHA, LM_BETA)
        lm_load_end = timer() - lm_load_start
        print('Loaded language model in {:.3f}s.'.format(lm_load_end), file=sys.stderr)

    fin = wave.open(args.audio, 'rb')
    try:
        fs_orig = fin.getframerate()
        n_frames = fin.getnframes()
        # Duration comes from the file's own frame rate: the frame count
        # refers to the original audio, not to the resampled 16 kHz stream.
        audio_length = n_frames / float(fs_orig) if fs_orig else 0.0
        if fs_orig != 16000:
            print('Warning: original sample rate ({}) is different than 16kHz. Resampling might produce erratic speech recognition.'.format(fs_orig), file=sys.stderr)
            fs, audio = convert_samplerate(args.audio)
        else:
            fs = fs_orig
            audio = np.frombuffer(fin.readframes(n_frames), np.int16)
    finally:
        # Release the file handle even if reading or resampling fails.
        fin.close()

    print('Running inference.', file=sys.stderr)
    inference_start = timer()
    if args.extended:
        print(metadata_to_string(ds.sttWithMetadata(audio, fs)))
    else:
        print(ds.stt(audio, fs))
    inference_end = timer() - inference_start
    print('Inference took %0.3fs for %0.3fs audio file.' % (inference_end, audio_length), file=sys.stderr)
# Run the command-line client only when executed as a script, not on import.
if __name__ == '__main__':
    main()