asr_test.py
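"""Manual test harness for the multiprocessing ASR handler.

Launches a live Gradio interface that streams microphone audio to
ASRHandlerMultiprocessing, shows the running transcription (merging partial
and finalized segments), and can collect the audio actually consumed by the
handler for playback. Expects a CUDA device at "cuda:1".
"""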
import re
from datetime import datetime

import gradio as gr
import numpy as np
import torch

from realtime_chatbot.asr_handler import ASRHandlerMultiprocessing, ASRConfig
class ASRTestGradioInteface:
    def __init__(self):
        # Run the ASR handler in a separate process; output_debug_audio=True lets us
        # collect the audio it actually consumed and play it back in the UI.
        self.asr_handler = ASRHandlerMultiprocessing(
            device=torch.device("cuda:1"),
            wait_until_running=False,
            output_debug_audio=True
        )
        self.asr_handler.wait_until_running()
    def update_transcription(self, transcription, new_text, partial_pos):
        if new_text:
            # First, clear out the previous partial segment (if it exists).
            if partial_pos > -1:
                transcription = transcription[:partial_pos]
                partial_pos = -1
            # Next, add the new segments to the transcription,
            # discarding intermediate partial segments.
            new_segments = re.split(" (?=[~*])", new_text)
            for i, seg in enumerate(new_segments):
                # Keep finalized segments ("*") and the trailing segment, which may
                # still be a partial ("~"); remember where a kept partial starts so
                # it can be replaced on the next update.
                if len(seg) > 1 and (seg.startswith("*") or i == len(new_segments) - 1):
                    if seg.startswith("~"):
                        partial_pos = len(transcription)
                    if len(transcription) > 0:
                        transcription += " "
                    transcription += seg[1:]
        return transcription, partial_pos
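    # Illustrative merge behaviour, assuming the handler prefixes finalized segments
    # with "*" and in-progress partials with "~" (the strings below are hypothetical):
    #   update_transcription("", "*hello there ~how are", -1)
    #       -> ("hello there how are", 11)     # partial kept; its start position remembered
    #   update_transcription("hello there how are", "*how are you", 11)
    #       -> ("hello there how are you", -1)  # partial replaced by the finalized segment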
    def execute(self, state, audio, reset, collect, simulate_load, asr_max_buffer_size, asr_model_size,
                asr_logprob_threshold, asr_no_speech_threshold, asr_lang):
        # Treat the auto-detect choice as "no fixed language".
        if asr_lang == self.asr_handler.AUTO_DETECT_LANG:
            asr_lang = None

        # Queue up config changes, in case any were made.
        asr_config = ASRConfig(model_size=asr_model_size, lang=asr_lang, logprob_threshold=asr_logprob_threshold,
                               no_speech_threshold=asr_no_speech_threshold, max_buffer_size=asr_max_buffer_size)
        if asr_config != state["asr_config"]:
            state["asr_config"] = asr_config
            self.asr_handler.queue_config(asr_config)

        # If there is collected audio and collect has been switched off, output it.
        collected_audio = state["collected_audio"]
        collected_audio_concat = None
        if not collect and len(collected_audio) > 0:
            collected_audio_concat = (collected_audio[0][0], np.concatenate([ca[1] for ca in collected_audio]))
            collected_audio.clear()

        # If there is audio input, queue it up for ASR.
        if audio is not None:
            self.asr_handler.queue_input(audio)

        # If there is ASR debug audio output, collect it.
        asr_debug_audio = self.asr_handler.next_debug_audio()
        if collect and asr_debug_audio is not None:
            collected_audio.append(asr_debug_audio)

        # If there is ASR output, append it to the display.
        transcription = state["transcription"]
        partial_pos = state["partial_pos"]
        if reset:
            transcription = ""
            partial_pos = -1
        new_text = self.asr_handler.next_output()
        transcription, partial_pos = self.update_transcription(transcription, new_text, partial_pos)
        state["transcription"] = transcription
        state["partial_pos"] = partial_pos

        if simulate_load:
            # Busy-wait for one second to simulate a heavily loaded main process.
            then = datetime.now()
            while (datetime.now() - then).total_seconds() < 1:
                pass

        return state, transcription, collected_audio_concat
    def launch(self):
        # Use the current-style component classes consistently (the deprecated
        # gr.inputs.* aliases take "default=" instead of "value=").
        asr_model_size = gr.Dropdown(label="ASR Model size", choices=self.asr_handler.available_model_sizes, value='small.en')
        asr_max_buffer_size_slider = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="ASR max buffer size")
        asr_logprob_threshold_slider = gr.Slider(minimum=-3.0, maximum=0.0, value=-0.7, step=0.05, label="ASR Log prob threshold")
        asr_no_speech_threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.6, step=0.05, label="ASR No speech threshold")
        asr_lang_dropdown = gr.Dropdown(choices=self.asr_handler.available_languages, label="ASR Language", value="English")
        reset_button = gr.Checkbox(value=True, label="Reset")
        collect_button = gr.Checkbox(value=True, label="Collect Audio")
        simulate_load_button = gr.Checkbox(value=False, label="Simulate Load")
        state = gr.State({
            "transcription": "",
            "partial_pos": -1,
            "asr_config": ASRConfig(),
            "collected_audio": []
        })
        output_textbox = gr.Textbox(label="Output")

        interface = gr.Interface(
            fn=self.execute,
            inputs=[
                state,
                gr.Audio(source="microphone", streaming=True, label="ASR Input"),
                reset_button,
                collect_button,
                simulate_load_button,
                asr_max_buffer_size_slider,
                asr_model_size,
                asr_logprob_threshold_slider,
                asr_no_speech_threshold_slider,
                asr_lang_dropdown
            ],
            outputs=[
                state,
                output_textbox,
                gr.Audio(label="Collected Audio")
            ],
            live=True,
            allow_flagging='never',
            title="ASR Test",
            description="ASR Test"
        )
        interface.launch()
if __name__ == "__main__":
    interface = ASRTestGradioInteface()
    interface.launch()
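# Example usage (assumes the realtime_chatbot package is importable and a second
# CUDA device, "cuda:1", is available for the ASR handler):
#   python asr_test.py
# Then open the local URL that Gradio prints and speak into the microphone.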