[WIP] Refactoring streaming transducer (#78)
* Refactoring streaming transducer

* Refactor code for emformer

* Refactor conv emformer

* Rename decode.py to stream.py

* Refactor offline code

* Refactor offline_asr

* Fix code according to review

* Fix flake8 file and black formatting

* Fix offline beam search

* Add keyword arguments
ezerhouni authored Jul 23, 2022
1 parent f723283 commit fca8c18
Showing 14 changed files with 1,130 additions and 891 deletions.
3 changes: 2 additions & 1 deletion .flake8
@@ -8,4 +8,5 @@ exclude =
 ./cmake,
 ./triton,
 ./sherpa/python/sherpa/__init__.py,
-./sherpa/python/sherpa/decode.py
+./sherpa/python/sherpa/decode.py,
+./sherpa/python/bin
260 changes: 260 additions & 0 deletions sherpa/bin/conv_emformer_transducer_stateless2/beam_search.py
@@ -0,0 +1,260 @@
from typing import List

import k2
import torch
from stream import Stream, stack_states, unstack_states

from sherpa import fast_beam_search_one_best, streaming_greedy_search


class FastBeamSearch:
    def __init__(
        self,
        vocab_size: int,
        context_size: int,
        beam: int,
        max_states: int,
        max_contexts: int,
        device: torch.device,
    ):
        """
        Args:
          vocab_size:
            Vocabulary size of the BPE model.
          context_size:
            Context size of the RNN-T decoder model.
          beam:
            The beam for fast_beam_search decoding.
          max_states:
            The max_states for fast_beam_search decoding.
          max_contexts:
            The max_contexts for fast_beam_search decoding.
          device:
            Device on which the computation will occur.
        """
        self.rnnt_decoding_config = k2.RnntDecodingConfig(
            vocab_size=vocab_size,
            decoder_history_len=context_size,
            beam=beam,
            max_states=max_states,
            max_contexts=max_contexts,
        )
        self.decoding_graph = k2.trivial_graph(vocab_size - 1, device)
        self.device = device
        self.context_size = context_size

    def init_stream(self, stream: Stream):
        """
        Attach decoding-specific attributes to the given stream.
        """
        stream.rnnt_decoding_stream = k2.RnntDecodingStream(self.decoding_graph)
        stream.hyp = []

    @torch.no_grad()
    def process(
        self,
        server: "StreamingServer",
        stream_list: List[Stream],
    ) -> None:
        """Run the model on the given streams and decode them with the
        fast_beam_search method.
        Args:
          server:
            An instance of `StreamingServer`.
          stream_list:
            A list of streams to be processed. It is changed in-place.
            That is, the attributes `states` and `hyp` are updated in-place.
        """
        model = server.model
        device = model.device
        # Note: chunk_length is in frames before subsampling
        chunk_length = server.chunk_length
        batch_size = len(stream_list)
        chunk_length_pad = server.chunk_length_pad
        state_list, feature_list = [], []
        processed_frames_list, rnnt_decoding_streams_list = [], []

        rnnt_decoding_config = self.rnnt_decoding_config
        for s in stream_list:
            rnnt_decoding_streams_list.append(s.rnnt_decoding_stream)
            state_list.append(s.states)
            processed_frames_list.append(s.processed_frames)
            # `s.features` is a list of 2-D tensors; take the padded chunk
            # for this step and keep the overlap for the next one.
            f = s.features[:chunk_length_pad]
            s.features = s.features[chunk_length:]
            s.processed_frames += chunk_length

            b = torch.cat(f, dim=0)
            feature_list.append(b)

        features = torch.stack(feature_list, dim=0).to(device)

        states = stack_states(state_list)

        features_length = torch.full(
            (batch_size,),
            fill_value=features.size(1),
            device=device,
            dtype=torch.int64,
        )

        num_processed_frames = torch.tensor(
            processed_frames_list, device=device
        )

        (
            encoder_out,
            encoder_out_lens,
            next_states,
        ) = model.encoder_streaming_forward(
            features=features,
            features_length=features_length,
            num_processed_frames=num_processed_frames,
            states=states,
        )

        # num_processed_frames counts frames before subsampling; the encoder
        # subsamples by a factor of 4, hence the right shift by 2.
        processed_lens = (num_processed_frames >> 2) + encoder_out_lens
        next_hyp_list = fast_beam_search_one_best(
            model=model,
            encoder_out=encoder_out,
            processed_lens=processed_lens,
            rnnt_decoding_config=rnnt_decoding_config,
            rnnt_decoding_streams_list=rnnt_decoding_streams_list,
        )

        next_state_list = unstack_states(next_states)
        for i, s in enumerate(stream_list):
            s.states = next_state_list[i]
            s.hyp = next_hyp_list[i]

    def get_texts(self, stream: Stream):
        """
        Return the decoded text of the given stream.
        Args:
          stream:
            Stream to be processed.
        """
        # Note: `self.sp` (a sentencepiece processor) is assumed to be
        # attached to this object by the server before decoding.
        return self.sp.decode(stream.hyp)


class GreedySearch:
    def __init__(self, model: "RnntConvEmformerModel", device: torch.device):
        """
        Args:
          model:
            The RNN-T model.
          device:
            Device on which the computation will occur.
        """

        self.blank_id = model.blank_id
        self.context_size = model.context_size
        self.device = device

        decoder_input = torch.tensor(
            [[self.blank_id] * self.context_size],
            device=self.device,
            dtype=torch.int64,
        )
        initial_decoder_out = model.decoder_forward(decoder_input)
        self.initial_decoder_out = model.forward_decoder_proj(
            initial_decoder_out.squeeze(1)
        )

    def init_stream(self, stream: Stream):
        """
        Attach decoding-specific attributes to the given stream.
        """
        stream.decoder_out = self.initial_decoder_out
        stream.hyp = [self.blank_id] * self.context_size

    @torch.no_grad()
    def process(
        self,
        server: "StreamingServer",
        stream_list: List[Stream],
    ) -> None:
        """Run the model on the given streams and decode them with the
        greedy_search method.
        Args:
          server:
            An instance of `StreamingServer`.
          stream_list:
            A list of streams to be processed. It is changed in-place.
            That is, the attributes `states` and `hyp` are updated in-place.
        """
        model = server.model
        device = model.device
        # Note: chunk_length is in frames before subsampling
        chunk_length = server.chunk_length
        batch_size = len(stream_list)
        chunk_length_pad = server.chunk_length_pad
        state_list, feature_list = [], []
        decoder_out_list, hyp_list = [], []
        processed_frames_list = []

        for s in stream_list:
            decoder_out_list.append(s.decoder_out)
            hyp_list.append(s.hyp)
            state_list.append(s.states)
            processed_frames_list.append(s.processed_frames)
            f = s.features[:chunk_length_pad]
            s.features = s.features[chunk_length:]
            s.processed_frames += chunk_length

            b = torch.cat(f, dim=0)
            feature_list.append(b)

        features = torch.stack(feature_list, dim=0).to(device)
        states = stack_states(state_list)
        decoder_out = torch.cat(decoder_out_list, dim=0)

        features_length = torch.full(
            (batch_size,),
            fill_value=features.size(1),
            device=device,
            dtype=torch.int64,
        )

        num_processed_frames = torch.tensor(
            processed_frames_list, device=device
        )

        (
            encoder_out,
            encoder_out_lens,
            next_states,
        ) = model.encoder_streaming_forward(
            features=features,
            features_length=features_length,
            num_processed_frames=num_processed_frames,
            states=states,
        )

        # Note: It does not return next_encoder_out_lens since there is
        # no padding for streaming ASR. Each stream has the same number
        # of input frames, i.e., server.chunk_length.
        next_decoder_out, next_hyp_list = streaming_greedy_search(
            model=model,
            encoder_out=encoder_out,
            decoder_out=decoder_out,
            hyps=hyp_list,
        )

        next_decoder_out_list = next_decoder_out.split(1)

        next_state_list = unstack_states(next_states)
        for i, s in enumerate(stream_list):
            s.states = next_state_list[i]
            s.decoder_out = next_decoder_out_list[i]
            s.hyp = next_hyp_list[i]

    def get_texts(self, stream: Stream):
        """
        Return the decoded text of the given stream.
        Args:
          stream:
            Stream to be processed.
        """
        # Note: `self.sp` (a sentencepiece processor) is assumed to be
        # attached to this object by the server before decoding. The first
        # `context_size` tokens of `hyp` are the blank prefix and are
        # stripped before converting token IDs to text.
        return self.sp.decode(stream.hyp[self.context_size :])
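
Both classes expose the same three methods (`init_stream`, `process`, `get_texts`), so the server can stay agnostic about the decoding method. Below is a minimal dispatch sketch, not code from this commit: the `SearchMethod` protocol, the `create_search` helper, the `args` namespace, and the `model.vocab_size` attribute are all illustrative assumptions.

from typing import List, Protocol

import torch


class SearchMethod(Protocol):
    """The interface shared by FastBeamSearch and GreedySearch."""

    def init_stream(self, stream: Stream) -> None:
        ...

    def process(
        self, server: "StreamingServer", stream_list: List[Stream]
    ) -> None:
        ...

    def get_texts(self, stream: Stream) -> str:
        ...


def create_search(args, model, device: torch.device) -> SearchMethod:
    """Pick a search class from a hypothetical `args` namespace."""
    if args.decoding_method == "greedy_search":
        return GreedySearch(model, device)
    elif args.decoding_method == "fast_beam_search":
        return FastBeamSearch(
            vocab_size=model.vocab_size,  # assumed model attribute
            context_size=model.context_size,
            beam=args.beam,
            max_states=args.max_states,
            max_contexts=args.max_contexts,
            device=device,
        )
    else:
        raise ValueError(
            f"Decoding method {args.decoding_method} is not supported"
        )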
@@ -15,9 +15,8 @@
 # limitations under the License.

 import math
-from typing import List, Optional, Tuple
+from typing import List, Tuple

-import k2
 import torch
 from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature

@@ -154,29 +153,15 @@ class Stream(object):
     def __init__(
         self,
         context_size: int,
-        blank_id: int,
         initial_states: List[List[torch.Tensor]],
-        decoding_method: str = "greedy_search",
-        decoding_graph: Optional[k2.Fsa] = None,
-        decoder_out: Optional[torch.Tensor] = None,
     ) -> None:
         """
         Args:
           context_size:
             Context size of the RNN-T decoder model.
-          blank_id:
-            Blank token ID of the BPE model.
           initial_states:
             The initial states of the Emformer model. Note that the state
             does not contain the batch dimension.
-          decoding_method:
-            The decoding method to use, currently, only greedy_search and
-            fast_beam_search are supported.
-          decoding_graph:
-            The Fsa based decoding graph for fast_beam_search.
-          decoder_out:
-            The initial decoder out corresponding to the decoder input
-            `[blank_id]*context_size`
         """
         self.feature_extractor = _create_streaming_feature_extractor()
         # It contains a list of 2-D tensors representing the feature frames.
@@ -185,26 +170,8 @@ def __init__(
         self.num_fetched_frames = 0

         self.states = initial_states
-        self.decoding_graph = decoding_graph
-
-        if decoding_method == "fast_beam_search":
-            assert decoding_graph is not None
-            self.rnnt_decoding_stream = k2.RnntDecodingStream(decoding_graph)
-            self.hyp = []
-        elif decoding_method == "greedy_search":
-            assert decoder_out is not None
-            self.decoder_out = decoder_out
-            self.hyp = [blank_id] * context_size
-        else:
-            # fmt: off
-            raise ValueError(
-                f"Decoding method {decoding_method} is not supported."
-            )
-            # fmt: on
-
         self.processed_frames = 0
         self.context_size = context_size
-        self.hyp = [blank_id] * context_size
         self.log_eps = math.log(1e-10)

     def accept_waveform(
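
The net effect of this change: `Stream` now carries only generic streaming state, and decoding-specific attributes (`hyp`, `decoder_out`, `rnnt_decoding_stream`) are attached by the chosen search class via `init_stream`. A minimal sketch of the new flow, where `model`, `device`, and `initial_states` are placeholders for values obtained from the actual server:

# Illustrative sketch, not code from this commit.
stream = Stream(
    context_size=model.context_size,
    initial_states=initial_states,  # placeholder: initial Emformer states
)

search = GreedySearch(model, device)
# init_stream attaches what Stream.__init__ used to set up itself:
# for greedy search, stream.decoder_out and stream.hyp.
search.init_stream(stream)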