
Commit fca8c18

[WIP] Refactoring streaming transducer (#78)
* Refactoring streaming transducer
* Refactor code for emformer
* Refactor conv emformer
* Rename decode.py to stream.py
* Refactor offline code
* Refactor offline_asr
* Fix code according to review
* Fix flake8 file and black formatting
* Fix offline beam search
* Add keyword arguments
1 parent f723283 commit fca8c18

14 files changed: +1130 -891 lines changed


.flake8: +2 -1

@@ -8,4 +8,5 @@ exclude =
   ./cmake,
   ./triton,
   ./sherpa/python/sherpa/__init__.py,
-  ./sherpa/python/sherpa/decode.py
+  ./sherpa/python/sherpa/decode.py,
+  ./sherpa/python/bin

@@ -0,0 +1,260 @@
from typing import List

import k2
import torch
from stream import Stream, stack_states, unstack_states

from sherpa import fast_beam_search_one_best, streaming_greedy_search


class FastBeamSearch:
    def __init__(
        self,
        vocab_size: int,
        context_size: int,
        beam: int,
        max_states: int,
        max_contexts: int,
        device: torch.device,
    ):
        """
        Args:
          vocab_size:
            Vocabulary size of the BPE model.
          context_size:
            Context size of the RNN-T decoder model.
          beam:
            The beam for fast_beam_search decoding.
          max_states:
            The max_states for fast_beam_search decoding.
          max_contexts:
            The max_contexts for fast_beam_search decoding.
          device:
            Device on which the computation will occur.
        """
        self.rnnt_decoding_config = k2.RnntDecodingConfig(
            vocab_size=vocab_size,
            decoder_history_len=context_size,
            beam=beam,
            max_states=max_states,
            max_contexts=max_contexts,
        )
        self.decoding_graph = k2.trivial_graph(vocab_size - 1, device)
        self.device = device
        self.context_size = context_size

    def init_stream(self, stream: Stream):
        """
        Attributes to add to each stream
        """
        stream.rnnt_decoding_stream = k2.RnntDecodingStream(self.decoding_graph)
        stream.hyp = []

    @torch.no_grad()
    def process(
        self,
        server: "StreamingServer",
        stream_list: List[Stream],
    ) -> None:
        """Run the model on the given stream list and do search with the
        fast_beam_search method.
        Args:
          server:
            An instance of `StreamingServer`.
          stream_list:
            A list of streams to be processed. It is changed in-place.
            That is, the attributes `states` and `hyp` are updated in-place.
        """
        model = server.model
        device = model.device
        # Note: chunk_length is in frames before subsampling
        chunk_length = server.chunk_length
        batch_size = len(stream_list)
        chunk_length_pad = server.chunk_length_pad
        state_list, feature_list = [], []
        processed_frames_list, rnnt_decoding_streams_list = [], []

        rnnt_decoding_config = self.rnnt_decoding_config
        for s in stream_list:
            rnnt_decoding_streams_list.append(s.rnnt_decoding_stream)
            state_list.append(s.states)
            processed_frames_list.append(s.processed_frames)
            f = s.features[:chunk_length_pad]
            s.features = s.features[chunk_length:]
            s.processed_frames += chunk_length

            b = torch.cat(f, dim=0)
            feature_list.append(b)

        features = torch.stack(feature_list, dim=0).to(device)

        states = stack_states(state_list)

        features_length = torch.full(
            (batch_size,),
            fill_value=features.size(1),
            device=device,
            dtype=torch.int64,
        )

        num_processed_frames = torch.tensor(
            processed_frames_list, device=device
        )

        (
            encoder_out,
            encoder_out_lens,
            next_states,
        ) = model.encoder_streaming_forward(
            features=features,
            features_length=features_length,
            num_processed_frames=num_processed_frames,
            states=states,
        )

        processed_lens = (num_processed_frames >> 2) + encoder_out_lens
        next_hyp_list = fast_beam_search_one_best(
            model=model,
            encoder_out=encoder_out,
            processed_lens=processed_lens,
            rnnt_decoding_config=rnnt_decoding_config,
            rnnt_decoding_streams_list=rnnt_decoding_streams_list,
        )

        next_state_list = unstack_states(next_states)
        for i, s in enumerate(stream_list):
            s.states = next_state_list[i]
            s.hyp = next_hyp_list[i]

    def get_texts(self, stream: Stream):
        """
        Return text after decoding
        Args:
          stream:
            Stream to be processed.
        """
        return self.sp.decode(stream.hyp)


class GreedySearch:
    def __init__(self, model: "RnntConvEmformerModel", device: torch.device):
        """
        Args:
          model:
            The RNN-T model.
          device:
            Device on which the computation will occur.
        """

        self.blank_id = model.blank_id
        self.context_size = model.context_size
        self.device = device

        decoder_input = torch.tensor(
            [[self.blank_id] * self.context_size],
            device=self.device,
            dtype=torch.int64,
        )
        initial_decoder_out = model.decoder_forward(decoder_input)
        self.initial_decoder_out = model.forward_decoder_proj(
            initial_decoder_out.squeeze(1)
        )

    def init_stream(self, stream: Stream):
        """
        Attributes to add to each stream
        """
        stream.decoder_out = self.initial_decoder_out
        stream.hyp = [self.blank_id] * self.context_size

    @torch.no_grad()
    def process(
        self,
        server: "StreamingServer",
        stream_list: List[Stream],
    ) -> None:
        """Run the model on the given stream list and do search with the
        greedy_search method.
        Args:
          server:
            An instance of `StreamingServer`.
          stream_list:
            A list of streams to be processed. It is changed in-place.
            That is, the attributes `states` and `hyp` are updated in-place.
        """
        model = server.model
        device = model.device
        # Note: chunk_length is in frames before subsampling
        chunk_length = server.chunk_length
        batch_size = len(stream_list)
        chunk_length_pad = server.chunk_length_pad
        state_list, feature_list = [], []
        decoder_out_list, hyp_list = [], []
        processed_frames_list = []

        for s in stream_list:
            decoder_out_list.append(s.decoder_out)
            hyp_list.append(s.hyp)
            state_list.append(s.states)
            processed_frames_list.append(s.processed_frames)
            f = s.features[:chunk_length_pad]
            s.features = s.features[chunk_length:]
            s.processed_frames += chunk_length

            b = torch.cat(f, dim=0)
            feature_list.append(b)

        features = torch.stack(feature_list, dim=0).to(device)
        states = stack_states(state_list)
        decoder_out = torch.cat(decoder_out_list, dim=0)

        features_length = torch.full(
            (batch_size,),
            fill_value=features.size(1),
            device=device,
            dtype=torch.int64,
        )

        num_processed_frames = torch.tensor(
            processed_frames_list, device=device
        )

        (
            encoder_out,
            encoder_out_lens,
            next_states,
        ) = model.encoder_streaming_forward(
            features=features,
            features_length=features_length,
            num_processed_frames=num_processed_frames,
            states=states,
        )

        # Note: It does not return the next_encoder_out_len since
        # there are no paddings for streaming ASR. Each stream
        # has the same input number of frames, i.e., server.chunk_length.
        next_decoder_out, next_hyp_list = streaming_greedy_search(
            model=model,
            encoder_out=encoder_out,
            decoder_out=decoder_out,
            hyps=hyp_list,
        )

        next_decoder_out_list = next_decoder_out.split(1)

        next_state_list = unstack_states(next_states)
        for i, s in enumerate(stream_list):
            s.states = next_state_list[i]
            s.decoder_out = next_decoder_out_list[i]
            s.hyp = next_hyp_list[i]

    def get_texts(self, stream: Stream):
        """
        Return text after decoding
        Args:
          stream:
            Stream to be processed.
        """
        return self.sp.decode(stream.hyp[self.context_size :])
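
Both search classes batch per-stream features the same way inside `process()`: each stream's buffered 2-D feature frames are concatenated into one chunk, the chunks are stacked across streams, and `features_length` is filled with the padded chunk size. (In `FastBeamSearch.process()`, the `num_processed_frames >> 2` then divides the frame count by 4 to match the encoder's subsampling.) The snippet below is a minimal, self-contained sketch of that batching step only; it is not part of the commit, and the chunk sizes and feature dimension are made-up values.

# Sketch of the per-stream feature batching done in process() above.
# chunk_length, chunk_length_pad and feat_dim are illustrative values.
import torch

chunk_length, chunk_length_pad, feat_dim = 8, 10, 80

# Two fake streams, each holding a list of 2-D per-frame feature tensors,
# mirroring Stream.features.
stream_features = [
    [torch.randn(1, feat_dim) for _ in range(12)],
    [torch.randn(1, feat_dim) for _ in range(12)],
]

feature_list = []
for frames in stream_features:
    f = frames[:chunk_length_pad]       # current chunk plus right padding
    del frames[:chunk_length]           # consume chunk_length frames from the buffer
    feature_list.append(torch.cat(f, dim=0))

features = torch.stack(feature_list, dim=0)  # (batch, chunk_length_pad, feat_dim)
features_length = torch.full(
    (len(feature_list),), fill_value=features.size(1), dtype=torch.int64
)
print(features.shape, features_length)  # torch.Size([2, 10, 80]) tensor([10, 10])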

sherpa/bin/conv_emformer_transducer_stateless2/decode.py renamed to sherpa/bin/conv_emformer_transducer_stateless2/stream.py

+1 -34

@@ -15,9 +15,8 @@
 # limitations under the License.
 
 import math
-from typing import List, Optional, Tuple
+from typing import List, Tuple
 
-import k2
 import torch
 from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
 
@@ -154,29 +153,15 @@ class Stream(object):
     def __init__(
         self,
         context_size: int,
-        blank_id: int,
         initial_states: List[List[torch.Tensor]],
-        decoding_method: str = "greedy_search",
-        decoding_graph: Optional[k2.Fsa] = None,
-        decoder_out: Optional[torch.Tensor] = None,
     ) -> None:
         """
         Args:
           context_size:
             Context size of the RNN-T decoder model.
-          blank_id:
-            Blank token ID of the BPE model.
           initial_states:
             The initial states of the Emformer model. Note that the state
             does not contain the batch dimension.
-          decoding_method:
-            The decoding method to use, currently, only greedy_search and
-            fast_beam_search are supported.
-          decoding_graph:
-            The Fsa based decoding graph for fast_beam_search.
-          decoder_out:
-            The initial decoder out corresponding to the decoder input
-            `[blank_id]*context_size`
         """
         self.feature_extractor = _create_streaming_feature_extractor()
         # It contains a list of 2-D tensors representing the feature frames.
@@ -185,26 +170,8 @@ def __init__(
         self.num_fetched_frames = 0
 
         self.states = initial_states
-        self.decoding_graph = decoding_graph
-
-        if decoding_method == "fast_beam_search":
-            assert decoding_graph is not None
-            self.rnnt_decoding_stream = k2.RnntDecodingStream(decoding_graph)
-            self.hyp = []
-        elif decoding_method == "greedy_search":
-            assert decoder_out is not None
-            self.decoder_out = decoder_out
-            self.hyp = [blank_id] * context_size
-        else:
-            # fmt: off
-            raise ValueError(
-                f"Decoding method {decoding_method} is not supported."
-            )
-            # fmt: on
-
         self.processed_frames = 0
         self.context_size = context_size
-        self.hyp = [blank_id] * context_size
         self.log_eps = math.log(1e-10)
 
     def accept_waveform(
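
With `blank_id`, `decoding_method`, `decoding_graph`, and `decoder_out` removed, a `Stream` is now constructed from only `context_size` and `initial_states`, and each search class attaches its own per-stream decoding state through `init_stream()`, as `GreedySearch.init_stream()` does earlier in this commit. The toy classes below, `MiniStream` and `MiniGreedySearch`, are stand-ins invented for this sketch to illustrate that pattern; they are not code from the repository.

# Toy illustration of the refactored split: the stream holds generic state,
# the search object adds decoding-specific attributes in init_stream().
import torch


class MiniStream:
    def __init__(self, context_size: int, initial_states):
        self.context_size = context_size
        self.states = initial_states
        self.features = []          # filled by the feature extractor
        self.processed_frames = 0


class MiniGreedySearch:
    def __init__(self, blank_id: int, context_size: int):
        self.blank_id = blank_id
        self.context_size = context_size

    def init_stream(self, stream: MiniStream) -> None:
        # Decoding-specific state is created here, not in Stream.__init__.
        stream.decoder_out = torch.zeros(1, 512)   # placeholder decoder projection
        stream.hyp = [self.blank_id] * self.context_size


stream = MiniStream(context_size=2, initial_states=[])
MiniGreedySearch(blank_id=0, context_size=2).init_stream(stream)
print(stream.hyp)  # [0, 0]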
