diff --git a/.flake8 b/.flake8
index 670bbcf9..b3ddd892 100644
--- a/.flake8
+++ b/.flake8
@@ -8,4 +8,5 @@ exclude =
   ./cmake,
   ./triton,
   ./sherpa/python/sherpa/__init__.py,
-  ./sherpa/python/sherpa/decode.py
+  ./sherpa/python/sherpa/decode.py,
+  ./sherpa/python/bin
diff --git a/sherpa/bin/conv_emformer_transducer_stateless2/beam_search.py b/sherpa/bin/conv_emformer_transducer_stateless2/beam_search.py
new file mode 100644
index 00000000..f4509598
--- /dev/null
+++ b/sherpa/bin/conv_emformer_transducer_stateless2/beam_search.py
@@ -0,0 +1,260 @@
+from typing import List
+
+import k2
+import torch
+from stream import Stream, stack_states, unstack_states
+
+from sherpa import fast_beam_search_one_best, streaming_greedy_search
+
+
+class FastBeamSearch:
+    def __init__(
+        self,
+        vocab_size: int,
+        context_size: int,
+        beam: int,
+        max_states: int,
+        max_contexts: int,
+        device: torch.device,
+    ):
+        """
+        Args:
+          vocab_size:
+            Vocabulary size of the BPE model.
+          context_size:
+            Context size of the RNN-T decoder model.
+          beam:
+            The beam for fast_beam_search decoding.
+          max_states:
+            The max_states for fast_beam_search decoding.
+          max_contexts:
+            The max_contexts for fast_beam_search decoding.
+          device:
+            Device on which the computation will occur
+        """
+        self.rnnt_decoding_config = k2.RnntDecodingConfig(
+            vocab_size=vocab_size,
+            decoder_history_len=context_size,
+            beam=beam,
+            max_states=max_states,
+            max_contexts=max_contexts,
+        )
+        self.decoding_graph = k2.trivial_graph(vocab_size - 1, device)
+        self.device = device
+        self.context_size = context_size
+
+    def init_stream(self, stream: Stream):
+        """
+        Attributes to add to each stream
+        """
+        stream.rnnt_decoding_stream = k2.RnntDecodingStream(self.decoding_graph)
+        stream.hyp = []
+
+    @torch.no_grad()
+    def process(
+        self,
+        server: "StreamingServer",
+        stream_list: List[Stream],
+    ) -> None:
+        """Run the model on the given stream list and do search with the
+           fast_beam_search method.
+        Args:
+          server:
+            An instance of `StreamingServer`.
+          stream_list:
+            A list of streams to be processed. It is changed in-place.
+            That is, the attributes `states` and `hyp` are
+            updated in-place.
+ """ + model = server.model + device = model.device + # Note: chunk_length is in frames before subsampling + chunk_length = server.chunk_length + batch_size = len(stream_list) + chunk_length_pad = server.chunk_length_pad + state_list, feature_list = [], [] + processed_frames_list, rnnt_decoding_streams_list = [], [] + + rnnt_decoding_config = self.rnnt_decoding_config + for s in stream_list: + rnnt_decoding_streams_list.append(s.rnnt_decoding_stream) + state_list.append(s.states) + processed_frames_list.append(s.processed_frames) + f = s.features[:chunk_length_pad] + s.features = s.features[chunk_length:] + s.processed_frames += chunk_length + + b = torch.cat(f, dim=0) + feature_list.append(b) + + features = torch.stack(feature_list, dim=0).to(device) + + states = stack_states(state_list) + + features_length = torch.full( + (batch_size,), + fill_value=features.size(1), + device=device, + dtype=torch.int64, + ) + + num_processed_frames = torch.tensor( + processed_frames_list, device=device + ) + + ( + encoder_out, + encoder_out_lens, + next_states, + ) = model.encoder_streaming_forward( + features=features, + features_length=features_length, + num_processed_frames=num_processed_frames, + states=states, + ) + + processed_lens = (num_processed_frames >> 2) + encoder_out_lens + next_hyp_list = fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + rnnt_decoding_config=rnnt_decoding_config, + rnnt_decoding_streams_list=rnnt_decoding_streams_list, + ) + + next_state_list = unstack_states(next_states) + for i, s in enumerate(stream_list): + s.states = next_state_list[i] + s.hyp = next_hyp_list[i] + + def get_texts(self, stream: Stream): + """ + Return text after decoding + Args: + stream: + Stream to be processed. + """ + return self.sp.decode(stream.hyp) + + +class GreedySearch: + def __init__(self, model: "RnntConvEmformerModel", device: torch.device): + """ + Args: + model: + RNN-T model decoder model + device: + Device on which the computation will occur + """ + + self.blank_id = model.blank_id + self.context_size = model.context_size + self.device = device + + decoder_input = torch.tensor( + [[self.blank_id] * self.context_size], + device=self.device, + dtype=torch.int64, + ) + initial_decoder_out = model.decoder_forward(decoder_input) + self.initial_decoder_out = model.forward_decoder_proj( + initial_decoder_out.squeeze(1) + ) + + def init_stream(self, stream: Stream): + """ + Attributes to add to each stream + """ + stream.decoder_out = self.initial_decoder_out + stream.hyp = [self.blank_id] * self.context_size + + @torch.no_grad() + def process( + self, + server: "StreamingServer", + stream_list: List[Stream], + ) -> None: + """Run the model on the given stream list and do search with greedy_search + method. + Args: + server: + An instance of `StreamingServer`. + stream_list: + A list of streams to be processed. It is changed in-place. + That is, the attribute `states` and `hyp` are + updated in-place. 
+ """ + model = server.model + device = model.device + # Note: chunk_length is in frames before subsampling + chunk_length = server.chunk_length + batch_size = len(stream_list) + chunk_length_pad = server.chunk_length_pad + state_list, feature_list = [], [] + decoder_out_list, hyp_list = [], [] + processed_frames_list = [] + + for s in stream_list: + decoder_out_list.append(s.decoder_out) + hyp_list.append(s.hyp) + state_list.append(s.states) + processed_frames_list.append(s.processed_frames) + f = s.features[:chunk_length_pad] + s.features = s.features[chunk_length:] + s.processed_frames += chunk_length + + b = torch.cat(f, dim=0) + feature_list.append(b) + + features = torch.stack(feature_list, dim=0).to(device) + states = stack_states(state_list) + decoder_out = torch.cat(decoder_out_list, dim=0) + + features_length = torch.full( + (batch_size,), + fill_value=features.size(1), + device=device, + dtype=torch.int64, + ) + + num_processed_frames = torch.tensor( + processed_frames_list, device=device + ) + + ( + encoder_out, + encoder_out_lens, + next_states, + ) = model.encoder_streaming_forward( + features=features, + features_length=features_length, + num_processed_frames=num_processed_frames, + states=states, + ) + + # Note: It does not return the next_encoder_out_len since + # there are no paddings for streaming ASR. Each stream + # has the same input number of frames, i.e., server.chunk_length. + next_decoder_out, next_hyp_list = streaming_greedy_search( + model=model, + encoder_out=encoder_out, + decoder_out=decoder_out, + hyps=hyp_list, + ) + + next_decoder_out_list = next_decoder_out.split(1) + + next_state_list = unstack_states(next_states) + for i, s in enumerate(stream_list): + s.states = next_state_list[i] + s.decoder_out = next_decoder_out_list[i] + s.hyp = next_hyp_list[i] + + def get_texts(self, stream: Stream): + """ + Return text after decoding + Args: + stream: + Stream to be processed. + """ + return self.sp.decode(stream.hyp[self.context_size :]) diff --git a/sherpa/bin/conv_emformer_transducer_stateless2/decode.py b/sherpa/bin/conv_emformer_transducer_stateless2/stream.py similarity index 87% rename from sherpa/bin/conv_emformer_transducer_stateless2/decode.py rename to sherpa/bin/conv_emformer_transducer_stateless2/stream.py index 4d2e6520..fc1007ac 100644 --- a/sherpa/bin/conv_emformer_transducer_stateless2/decode.py +++ b/sherpa/bin/conv_emformer_transducer_stateless2/stream.py @@ -15,9 +15,8 @@ # limitations under the License. import math -from typing import List, Optional, Tuple +from typing import List, Tuple -import k2 import torch from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature @@ -154,29 +153,15 @@ class Stream(object): def __init__( self, context_size: int, - blank_id: int, initial_states: List[List[torch.Tensor]], - decoding_method: str = "greedy_search", - decoding_graph: Optional[k2.Fsa] = None, - decoder_out: Optional[torch.Tensor] = None, ) -> None: """ Args: context_size: Context size of the RNN-T decoder model. - blank_id: - Blank token ID of the BPE model. initial_states: The initial states of the Emformer model. Note that the state does not contain the batch dimension. - decoding_method: - The decoding method to use, currently, only greedy_search and - fast_beam_search are supported. - decoding_graph: - The Fsa based decoding graph for fast_beam_search. 
- decoder_out: - The initial decoder out corresponding to the decoder input - `[blank_id]*context_size` """ self.feature_extractor = _create_streaming_feature_extractor() # It contains a list of 2-D tensors representing the feature frames. @@ -185,26 +170,8 @@ def __init__( self.num_fetched_frames = 0 self.states = initial_states - self.decoding_graph = decoding_graph - - if decoding_method == "fast_beam_search": - assert decoding_graph is not None - self.rnnt_decoding_stream = k2.RnntDecodingStream(decoding_graph) - self.hyp = [] - elif decoding_method == "greedy_search": - assert decoder_out is not None - self.decoder_out = decoder_out - self.hyp = [blank_id] * context_size - else: - # fmt: off - raise ValueError( - f"Decoding method {decoding_method} is not supported." - ) - # fmt: on - self.processed_frames = 0 self.context_size = context_size - self.hyp = [blank_id] * context_size self.log_eps = math.log(1e-10) def accept_waveform( diff --git a/sherpa/bin/conv_emformer_transducer_stateless2/streaming_server.py b/sherpa/bin/conv_emformer_transducer_stateless2/streaming_server.py index 249a0472..770cd15b 100755 --- a/sherpa/bin/conv_emformer_transducer_stateless2/streaming_server.py +++ b/sherpa/bin/conv_emformer_transducer_stateless2/streaming_server.py @@ -34,20 +34,16 @@ import math import warnings from concurrent.futures import ThreadPoolExecutor -from typing import List, Optional, Tuple +from typing import Optional, Tuple -import k2 import numpy as np import sentencepiece as spm import torch import websockets -from decode import Stream, stack_states, unstack_states +from beam_search import FastBeamSearch, GreedySearch +from stream import Stream -from sherpa import ( - RnntConvEmformerModel, - fast_beam_search_one_best, - streaming_greedy_search, -) +from sherpa import RnntConvEmformerModel def get_args(): @@ -175,114 +171,6 @@ def get_args(): return parser.parse_args() -@torch.no_grad() -def run_model_and_do_search( - server: "StreamingServer", - stream_list: List[Stream], -) -> None: - """Run the model on the given stream list and do greedy search. - Args: - server: - An instance of `StreamingServer`. - stream_list: - A list of streams to be processed. It is changed in-place. - That is, the attribute `states`, `decoder_out`, and `hyp` are - updated in-place. 
- """ - model = server.model - device = model.device - chunk_length = server.chunk_length - chunk_length_pad = server.chunk_length_pad - decoding_method = server.decoding_method - - batch_size = len(stream_list) - - state_list = [] - feature_list = [] - processed_frames_list = [] - if decoding_method == "greedy_search": - decoder_out_list = [] - hyp_list = [] - else: - rnnt_decoding_streams_list = [] - rnnt_decoding_config = server.rnnt_decoding_config - - for s in stream_list: - if decoding_method == "greedy_search": - decoder_out_list.append(s.decoder_out) - hyp_list.append(s.hyp) - elif decoding_method == "fast_beam_search": - rnnt_decoding_streams_list.append(s.rnnt_decoding_stream) - - state_list.append(s.states) - processed_frames_list.append(s.processed_frames) - f = s.features[:chunk_length_pad] - s.features = s.features[chunk_length:] - s.processed_frames += chunk_length - - b = torch.cat(f, dim=0) - feature_list.append(b) - - features = torch.stack(feature_list, dim=0).to(device) - - states = stack_states(state_list) - - if decoding_method == "greedy_search": - decoder_out = torch.cat(decoder_out_list, dim=0) - - features_length = torch.full( - (batch_size,), - fill_value=features.size(1), - device=device, - dtype=torch.int64, - ) - - num_processed_frames = torch.tensor(processed_frames_list, device=device) - - # fmt: off - ( - encoder_out, - encoder_out_lens, - next_states, - ) = model.encoder_streaming_forward( - features=features, - features_length=features_length, - num_processed_frames=num_processed_frames, - states=states, - ) - # fmt: on - - if decoding_method == "fast_beam_search": - processed_lens = (num_processed_frames >> 2) + encoder_out_lens - next_hyp_list = fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - rnnt_decoding_config=rnnt_decoding_config, - rnnt_decoding_streams_list=rnnt_decoding_streams_list, - ) - else: - # Note: It does not return the next_encoder_out_len since - # there are no paddings for streaming ASR. Each stream - # has the same input number of frames, i.e., server.chunk_length. - next_decoder_out, next_hyp_list = streaming_greedy_search( - model=model, - encoder_out=encoder_out, - decoder_out=decoder_out, - hyps=hyp_list, - ) - - if decoding_method == "greedy_search": - next_decoder_out_list = next_decoder_out.split(1) - - next_state_list = unstack_states(next_states) - for i, s in enumerate(stream_list): - s.states = next_state_list[i] - if decoding_method == "greedy_search": - s.decoder_out = next_decoder_out_list[i] - s.hyp = next_hyp_list[i] - - class StreamingServer(object): def __init__( self, @@ -312,8 +200,8 @@ def __init__( max_contexts: The max_contexts for fast_beam_search decoding. decoding_method: - The decoding method to use, can be either greedy_search or - fast_beam_search. + The decoding method to use, can be either greedy_search + or fast_beam_search. nn_pool_size: Number of threads for the thread pool that is responsible for neural network computation and decoding. 
@@ -356,33 +244,28 @@ def __init__( self.log_eps = math.log(1e-10) self.initial_states = self.model.get_encoder_init_states() - self.decoding_method = decoding_method - - assert self.decoding_method in ["greedy_search", "fast_beam_search"] - - self.initial_decoder_out = None - self.decoding_graph = None if decoding_method == "fast_beam_search": - self.rnnt_decoding_config = k2.RnntDecodingConfig( + self.beam_search = FastBeamSearch( vocab_size=self.vocab_size, - decoder_history_len=self.context_size, + context_size=self.context_size, beam=beam, max_states=max_states, max_contexts=max_contexts, - ) - self.decoding_graph = k2.trivial_graph(self.vocab_size - 1, device) - else: - decoder_input = torch.tensor( - [[self.blank_id] * self.context_size], device=device, - dtype=torch.int64, ) - initial_decoder_out = self.model.decoder_forward(decoder_input) - self.initial_decoder_out = self.model.forward_decoder_proj( - initial_decoder_out.squeeze(1) + elif decoding_method == "greedy_search": + self.beam_search = GreedySearch( + self.model, + device, + ) + else: + raise ValueError( + f"Decoding method {decoding_method} is not supported." ) + self.beam_search.sp = self.sp + self.nn_pool = ThreadPoolExecutor( max_workers=nn_pool_size, thread_name_prefix="nn", @@ -424,7 +307,7 @@ async def stream_consumer_task(self): loop = asyncio.get_running_loop() await loop.run_in_executor( self.nn_pool, - run_model_and_do_search, + self.beam_search.process, self, stream_list, ) @@ -520,13 +403,11 @@ async def handle_connection_impl( stream = Stream( context_size=self.context_size, - blank_id=self.blank_id, initial_states=self.initial_states, - decoding_method=self.decoding_method, - decoding_graph=self.decoding_graph, - decoder_out=self.initial_decoder_out, ) + self.beam_search.init_stream(stream) + while True: samples = await self.recv_audio_samples(socket) if samples is None: @@ -539,7 +420,7 @@ async def handle_connection_impl( while len(stream.features) > self.chunk_length_pad: await self.compute_and_decode(stream) await socket.send( - f"{self.sp.decode(stream.hyp[self.context_size:])}" + f"{self.beam_search.get_texts(stream)}" ) # noqa stream.input_finished() @@ -552,7 +433,7 @@ async def handle_connection_impl( await self.compute_and_decode(stream) stream.features = [] - result = self.sp.decode(stream.hyp[self.context_size :]) # noqa + result = self.beam_search.get_texts(stream) await socket.send(result) await socket.send("Done") @@ -646,8 +527,8 @@ def main(): """ if __name__ == "__main__": - # fmt: off + # fmt:off formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" # noqa - # fmt: on + # fmt:on logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/sherpa/bin/pruned_stateless_emformer_rnnt2/beam_search.py b/sherpa/bin/pruned_stateless_emformer_rnnt2/beam_search.py new file mode 100644 index 00000000..15ae1485 --- /dev/null +++ b/sherpa/bin/pruned_stateless_emformer_rnnt2/beam_search.py @@ -0,0 +1,340 @@ +from typing import List + +import k2 +import torch +from stream import Stream, stack_states, unstack_states + +from sherpa import ( + Hypotheses, + Hypothesis, + fast_beam_search_one_best, + streaming_greedy_search, + streaming_modified_beam_search, +) + + +class FastBeamSearch: + def __init__( + self, + vocab_size: int, + context_size: int, + beam: int, + max_states: int, + max_contexts: int, + device: torch.device, + ): + """ + Args: + vocab_size: + Vocabularize of the BPE + context_size: + Context size of the RNN-T decoder model. 
+ beam: + The beam for fast_beam_search decoding. + max_states: + The max_states for fast_beam_search decoding. + max_contexts: + The max_contexts for fast_beam_search decoding. + device: + Device on which the computation will occur + """ + self.rnnt_decoding_config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + ) + self.decoding_graph = k2.trivial_graph(vocab_size - 1, device) + self.device = device + self.context_size = context_size + + def init_stream(self, stream: Stream): + """ + Attributes to add to each stream + """ + stream.rnnt_decoding_stream = k2.RnntDecodingStream(self.decoding_graph) + stream.hyp = [] + + @torch.no_grad() + def process( + self, + server: "StreamingServer", + stream_list: List[Stream], + ) -> None: + """Run the model on the given stream list and do search with fast_beam_search + method. + Args: + server: + An instance of `StreamingServer`. + stream_list: + A list of streams to be processed. It is changed in-place. + That is, the attribute `states` and `hyp` are + updated in-place. + """ + model = server.model + device = model.device + # Note: chunk_length is in frames before subsampling + chunk_length = server.chunk_length + segment_length = server.segment_length + batch_size = len(stream_list) + + state_list, feature_list = [], [] + processed_frames_list, rnnt_decoding_streams_list = [], [] + + rnnt_decoding_config = self.rnnt_decoding_config + for s in stream_list: + rnnt_decoding_streams_list.append(s.rnnt_decoding_stream) + + state_list.append(s.states) + processed_frames_list.append(s.processed_frames) + f = s.features[:chunk_length] + s.features = s.features[segment_length:] + b = torch.cat(f, dim=0) + feature_list.append(b) + + features = torch.stack(feature_list, dim=0).to(device) + states = stack_states(state_list) + + features_length = torch.full( + (batch_size,), + fill_value=features.size(1), + device=device, + dtype=torch.int64, + ) + + processed_frames = torch.tensor(processed_frames_list, device=device) + + ( + encoder_out, + encoder_out_lens, + next_states, + ) = model.encoder_streaming_forward( # noqa + features=features, + features_length=features_length, + states=states, + ) + + processed_lens = processed_frames + encoder_out_lens + next_hyp_list = fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + rnnt_decoding_config=rnnt_decoding_config, + rnnt_decoding_streams_list=rnnt_decoding_streams_list, + ) + + next_state_list = unstack_states(next_states) + for i, s in enumerate(stream_list): + s.states = next_state_list[i] + s.processed_frames += encoder_out_lens[i] + s.hyp = next_hyp_list[i] + + def get_texts(self, stream: Stream): + """ + Return text after decoding + Args: + stream: + Stream to be processed. 
+ """ + return self.sp.decode(stream.hyp) + + +class GreedySearch: + def __init__(self, model: "RnntEmformerModel", device: torch.device): + """ + Args: + model: + RNN-T model decoder model + device: + Device on which the computation will occur + """ + + self.blank_id = model.blank_id + self.context_size = model.context_size + self.device = device + + decoder_input = torch.tensor( + [[self.blank_id] * self.context_size], + device=self.device, + dtype=torch.int64, + ) + initial_decoder_out = model.decoder_forward(decoder_input) + self.initial_decoder_out = model.forward_decoder_proj( + initial_decoder_out.squeeze(1) + ) + + def init_stream(self, stream: Stream): + """ + Attributes to add to each stream + """ + stream.decoder_out = self.initial_decoder_out + stream.hyp = [self.blank_id] * self.context_size + + @torch.no_grad() + def process( + self, + server: "StreamingServer", + stream_list: List[Stream], + ) -> None: + """Run the model on the given stream list and do search with greedy_search + method. + Args: + server: + An instance of `StreamingServer`. + stream_list: + A list of streams to be processed. It is changed in-place. + That is, the attribute `states` and `hyp` are + updated in-place. + """ + model = server.model + device = model.device + # Note: chunk_length is in frames before subsampling + chunk_length = server.chunk_length + batch_size = len(stream_list) + segment_length = server.segment_length + + state_list, feature_list = [], [] + decoder_out_list, hyp_list = [], [] + + for s in stream_list: + decoder_out_list.append(s.decoder_out) + hyp_list.append(s.hyp) + + state_list.append(s.states) + + f = s.features[:chunk_length] + s.features = s.features[segment_length:] + b = torch.cat(f, dim=0) + feature_list.append(b) + + features = torch.stack(feature_list, dim=0).to(device) + states = stack_states(state_list) + + decoder_out = torch.cat(decoder_out_list, dim=0) + + features_length = torch.full( + (batch_size,), + fill_value=features.size(1), + device=device, + dtype=torch.int64, + ) + + ( + encoder_out, + _, + next_states, + ) = model.encoder_streaming_forward( # noqa + features=features, + features_length=features_length, + states=states, + ) + + # Note: It does not return the next_encoder_out_len since + # there are no paddings for streaming ASR. Each stream + # has the same input number of frames, i.e., server.chunk_length. + next_decoder_out, next_hyp_list = streaming_greedy_search( + model=model, + encoder_out=encoder_out, + decoder_out=decoder_out, + hyps=hyp_list, + ) + + next_decoder_out_list = next_decoder_out.split(1) + + next_state_list = unstack_states(next_states) + for i, s in enumerate(stream_list): + s.states = next_state_list[i] + s.decoder_out = next_decoder_out_list[i] + s.hyp = next_hyp_list[i] + + def get_texts(self, stream: Stream): + """ + Return text after decoding + Args: + stream: + Stream to be processed. + """ + hyp = stream.hyp[self.context_size :] # noqa + return self.sp.decode(hyp) + + +class ModifiedBeamSearch: + def __init__(self, blank_id: int, context_size: int): + self.blank_id = blank_id + self.context_size = context_size + + def init_stream(self, stream: Stream): + """ + Attributes to add to each stream + """ + hyp = [self.blank_id] * self.context_size + stream.hyps = Hypotheses([Hypothesis(ys=hyp, log_prob=0.0)]) + + @torch.no_grad() + def process( + self, + server: "StreamingServer", + stream_list: List[Stream], + ) -> None: + """Run the model on the given stream list and do modified_beam_search. 
+ Args: + server: + An instance of `StreamingServer`. + stream_list: + A list of streams to be processed. It is changed in-place. + That is, the attribute `states` and `hyps` are + updated in-place. + """ + model = server.model + device = model.device + + segment_length = server.segment_length + chunk_length = server.chunk_length + + batch_size = len(stream_list) + + state_list = [] + hyps_list = [] + feature_list = [] + for s in stream_list: + state_list.append(s.states) + hyps_list.append(s.hyps) + + f = s.features[:chunk_length] + s.features = s.features[segment_length:] + + b = torch.cat(f, dim=0) + feature_list.append(b) + + features = torch.stack(feature_list, dim=0).to(device) + states = stack_states(state_list) + + features_length = torch.full( + (batch_size,), + fill_value=features.size(1), + device=device, + dtype=torch.int64, + ) + + (encoder_out, _, next_states) = model.encoder_streaming_forward( + features=features, + features_length=features_length, + states=states, + ) + # Note: There are no paddings for streaming ASR. Each stream + # has the same input number of frames, i.e., server.chunk_length. + next_hyps_list = streaming_modified_beam_search( + model=model, + encoder_out=encoder_out, + hyps=hyps_list, + ) + + next_state_list = unstack_states(next_states) + for i, s in enumerate(stream_list): + s.states = next_state_list[i] + s.hyps = next_hyps_list[i] + + def get_texts(self, stream: Stream): + hyp = stream.hyps.get_most_probable(True).ys[self.context_size :] + return self.sp.decode(hyp) diff --git a/sherpa/bin/pruned_stateless_emformer_rnnt2/decode.py b/sherpa/bin/pruned_stateless_emformer_rnnt2/stream.py similarity index 82% rename from sherpa/bin/pruned_stateless_emformer_rnnt2/decode.py rename to sherpa/bin/pruned_stateless_emformer_rnnt2/stream.py index dfd94551..095883e2 100644 --- a/sherpa/bin/pruned_stateless_emformer_rnnt2/decode.py +++ b/sherpa/bin/pruned_stateless_emformer_rnnt2/stream.py @@ -15,14 +15,11 @@ # limitations under the License. import math -from typing import List, Optional +from typing import List -import k2 import torch from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature -from sherpa import Hypotheses, Hypothesis - def unstack_states( states: List[List[torch.Tensor]], @@ -120,29 +117,15 @@ class Stream(object): def __init__( self, context_size: int, - blank_id: int, initial_states: List[List[torch.Tensor]], - decoding_method: str = "greedy_search", - decoding_graph: Optional[k2.Fsa] = None, - decoder_out: Optional[torch.Tensor] = None, ) -> None: """ Args: context_size: Context size of the RNN-T decoder model. - blank_id: - Blank token ID of the BPE model. initial_states: The initial states of the Emformer model. Note that the state does not contain the batch dimension. - decoding_method: - The decoding method to use, currently, only greedy_search and - fast_beam_search are supported. - decoding_graph: - The Fsa based decoding graph for fast_beam_search. - decoder_out: - Optional. The initial decoder out corresponding to the decoder input - `[blank_id]*context_size`. Used only for greedy_search. """ self.feature_extractor = _create_streaming_feature_extractor() # It contains a list of 2-D tensors representing the feature frames. 
@@ -151,26 +134,6 @@ def __init__( self.num_fetched_frames = 0 self.states = initial_states - self.decoding_graph = decoding_graph - - if decoding_method == "fast_beam_search": - assert decoding_graph is not None - self.rnnt_decoding_stream = k2.RnntDecodingStream(decoding_graph) - self.hyp = [] - elif decoding_method == "greedy_search": - assert decoder_out is not None - self.decoder_out = decoder_out - self.hyp = [blank_id] * context_size - elif decoding_method == "modified_beam_search": - hyp = [blank_id] * context_size - self.hyps = Hypotheses([Hypothesis(ys=hyp, log_prob=0.0)]) - else: - # fmt: off - raise ValueError( - f"Decoding method : {decoding_method} is not supported." - ) - # fmt: on - self.processed_frames = 0 self.context_size = context_size self.log_eps = math.log(1e-10) diff --git a/sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_server.py b/sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_server.py index 108405f9..1128383a 100755 --- a/sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_server.py +++ b/sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_server.py @@ -34,21 +34,16 @@ import math import warnings from concurrent.futures import ThreadPoolExecutor -from typing import List, Optional, Tuple +from typing import Optional, Tuple -import k2 import numpy as np import sentencepiece as spm import torch import websockets -from decode import Stream, stack_states, unstack_states +from beam_search import FastBeamSearch, GreedySearch, ModifiedBeamSearch +from stream import Stream, unstack_states -from sherpa import ( - RnntEmformerModel, - fast_beam_search_one_best, - streaming_greedy_search, - streaming_modified_beam_search, -) +from sherpa import RnntEmformerModel def get_args(): @@ -186,228 +181,6 @@ def get_args(): return parser.parse_args() -@torch.no_grad() -def run_model_and_do_greedy_search( - server: "StreamingServer", - stream_list: List[Stream], -) -> None: - """Run the model on the given stream list and do greedy search. - Args: - server: - An instance of `StreamingServer`. - stream_list: - A list of streams to be processed. It is changed in-place. - That is, the attribute `states`, `decoder_out`, and `hyp` are - updated in-place. - """ - model = server.model - device = model.device - segment_length = server.segment_length - chunk_length = server.chunk_length - - batch_size = len(stream_list) - - state_list = [] - feature_list = [] - - decoder_out_list = [] - hyp_list = [] - - for s in stream_list: - decoder_out_list.append(s.decoder_out) - hyp_list.append(s.hyp) - - state_list.append(s.states) - - f = s.features[:chunk_length] - s.features = s.features[segment_length:] - b = torch.cat(f, dim=0) - feature_list.append(b) - - features = torch.stack(feature_list, dim=0).to(device) - states = stack_states(state_list) - - decoder_out = torch.cat(decoder_out_list, dim=0) - - features_length = torch.full( - (batch_size,), - fill_value=features.size(1), - device=device, - dtype=torch.int64, - ) - - ( - encoder_out, - encoder_out_lens, - next_states, - ) = model.encoder_streaming_forward( # noqa - features=features, - features_length=features_length, - states=states, - ) - - # Note: It does not return the next_encoder_out_len since - # there are no paddings for streaming ASR. Each stream - # has the same input number of frames, i.e., server.chunk_length. 
- next_decoder_out, next_hyp_list = streaming_greedy_search( - model=model, - encoder_out=encoder_out, - decoder_out=decoder_out, - hyps=hyp_list, - ) - - next_decoder_out_list = next_decoder_out.split(1) - - next_state_list = unstack_states(next_states) - for i, s in enumerate(stream_list): - s.states = next_state_list[i] - s.decoder_out = next_decoder_out_list[i] - - s.hyp = next_hyp_list[i] - - -@torch.no_grad() -def run_model_and_do_fast_beam_search( - server: "StreamingServer", - stream_list: List[Stream], -) -> None: - """Run the model on the given stream list and do fast_beam_search. - Args: - server: - An instance of `StreamingServer`. - stream_list: - A list of streams to be processed. It is changed in-place. - That is, the attribute `states`, `decoder_out`, and `hyp` are - updated in-place. - """ - model = server.model - device = model.device - segment_length = server.segment_length - chunk_length = server.chunk_length - - batch_size = len(stream_list) - - state_list = [] - feature_list = [] - processed_frames_list = [] - - rnnt_decoding_streams_list = [] - rnnt_decoding_config = server.rnnt_decoding_config - - for s in stream_list: - rnnt_decoding_streams_list.append(s.rnnt_decoding_stream) - - state_list.append(s.states) - processed_frames_list.append(s.processed_frames) - - f = s.features[:chunk_length] - s.features = s.features[segment_length:] - b = torch.cat(f, dim=0) - feature_list.append(b) - - features = torch.stack(feature_list, dim=0).to(device) - states = stack_states(state_list) - - features_length = torch.full( - (batch_size,), - fill_value=features.size(1), - device=device, - dtype=torch.int64, - ) - - processed_frames = torch.tensor(processed_frames_list, device=device) - - ( - encoder_out, - encoder_out_lens, - next_states, - ) = model.encoder_streaming_forward( # noqa - features=features, - features_length=features_length, - states=states, - ) - - processed_lens = processed_frames + encoder_out_lens - next_hyp_list = fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - rnnt_decoding_config=rnnt_decoding_config, - rnnt_decoding_streams_list=rnnt_decoding_streams_list, - ) - - next_state_list = unstack_states(next_states) - for i, s in enumerate(stream_list): - s.states = next_state_list[i] - s.processed_frames += encoder_out_lens[i] - - s.hyp = next_hyp_list[i] - - -@torch.no_grad() -def run_model_and_do_modified_beam_search( - server: "StreamingServer", - stream_list: List[Stream], -) -> None: - """Run the model on the given stream list and do modified_beam_search. - Args: - server: - An instance of `StreamingServer`. - stream_list: - A list of streams to be processed. It is changed in-place. - That is, the attribute `states` and `hyps` are - updated in-place. 
- """ - model = server.model - device = model.device - - segment_length = server.segment_length - chunk_length = server.chunk_length - - batch_size = len(stream_list) - - state_list = [] - hyps_list = [] - feature_list = [] - for s in stream_list: - state_list.append(s.states) - hyps_list.append(s.hyps) - - f = s.features[:chunk_length] - s.features = s.features[segment_length:] - - b = torch.cat(f, dim=0) - feature_list.append(b) - - features = torch.stack(feature_list, dim=0).to(device) - states = stack_states(state_list) - - features_length = torch.full( - (batch_size,), - fill_value=features.size(1), - device=device, - dtype=torch.int64, - ) - - (encoder_out, _, next_states) = model.encoder_streaming_forward( - features=features, - features_length=features_length, - states=states, - ) - # Note: There are no paddings for streaming ASR. Each stream - # has the same input number of frames, i.e., server.chunk_length. - next_hyps_list = streaming_modified_beam_search( - model=model, - encoder_out=encoder_out, - hyps=hyps_list, - ) - - next_state_list = unstack_states(next_states) - for i, s in enumerate(stream_list): - s.states = next_state_list[i] - s.hyps = next_hyps_list[i] - - class StreamingServer(object): def __init__( self, @@ -489,42 +262,30 @@ def __init__( initial_states = self.model.get_encoder_init_states() self.initial_states = unstack_states(initial_states)[0] - self.decoding_method = decoding_method - assert self.decoding_method in [ - "greedy_search", - "fast_beam_search", - "modified_beam_search", - ] - - self.initial_decoder_out = None - self.decoding_graph = None if decoding_method == "fast_beam_search": - self.rnnt_decoding_config = k2.RnntDecodingConfig( + self.beam_search = FastBeamSearch( vocab_size=self.vocab_size, - decoder_history_len=self.context_size, + context_size=self.context_size, beam=beam, max_states=max_states, max_contexts=max_contexts, - ) - self.decoding_graph = k2.trivial_graph(self.vocab_size - 1, device) - self.run_nn_and_decode_func = run_model_and_do_fast_beam_search - elif decoding_method == "greedy_search": - decoder_input = torch.tensor( - [[self.blank_id] * self.context_size], device=device, - dtype=torch.int64, ) - initial_decoder_out = self.model.decoder_forward(decoder_input) - self.initial_decoder_out = self.model.forward_decoder_proj( - initial_decoder_out.squeeze(1) + elif decoding_method == "greedy_search": + self.beam_search = GreedySearch( + self.model, + device, ) - self.run_nn_and_decode_func = run_model_and_do_greedy_search elif decoding_method == "modified_beam_search": - self.run_nn_and_decode_func = run_model_and_do_modified_beam_search + self.beam_search = ModifiedBeamSearch( + self.blank_id, self.context_size + ) else: - raise ValueError(f"Unsupported method: {decoding_method}") + raise ValueError( + f"Decoding method {decoding_method} is not supported." 
+ ) - self.decoding_method = decoding_method + self.beam_search.sp = self.sp self.num_active_paths = num_active_paths self.nn_pool = ThreadPoolExecutor( @@ -568,7 +329,7 @@ async def stream_consumer_task(self): loop = asyncio.get_running_loop() await loop.run_in_executor( self.nn_pool, - self.run_nn_and_decode_func, + self.beam_search.process, self, stream_list, ) @@ -663,29 +424,14 @@ async def handle_connection_impl( ) stream = Stream( context_size=self.context_size, - blank_id=self.blank_id, initial_states=self.initial_states, - decoding_method=self.decoding_method, - decoding_graph=self.decoding_graph, - decoder_out=self.initial_decoder_out, ) - async def send_results(): - if self.decoding_method == "greedy_search": - hyp = stream.hyp[self.context_size :] # noqa - elif self.decoding_method == "modified_beam_search": - hyp = stream.hyps.get_most_probable(True).ys - hyp = hyp[self.context_size :] # noqa - elif self.decoding_method == "fast_beam_search": - hyp = stream.hyp - else: - # fmt: off - raise ValueError( - "Unsupported method " f"{self.decoding_method}" - ) - # fmt: on + self.beam_search.init_stream(stream) - await socket.send(f"{self.sp.decode(hyp)}") + async def send_results(): + result = self.beam_search.get_texts(stream) + await socket.send(f"{result}") while True: samples = await self.recv_audio_samples(socket) @@ -698,7 +444,6 @@ async def send_results(): while len(stream.features) > self.chunk_length: await self.compute_and_decode(stream) - await send_results() stream.input_finished() @@ -765,7 +510,6 @@ def main(): max_message_size = args.max_message_size max_queue_size = args.max_queue_size max_active_connections = args.max_active_connections - decoding_method = args.decoding_method num_active_paths = args.num_active_paths assert decoding_method in ( @@ -816,8 +560,8 @@ def main(): """ if __name__ == "__main__": - # fmt: off + # fmt:off formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" # noqa - # fmt: on + # fmt:on logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/sherpa/bin/pruned_transducer_statelessX/beam_search.py b/sherpa/bin/pruned_transducer_statelessX/beam_search.py new file mode 100644 index 00000000..effa5d7b --- /dev/null +++ b/sherpa/bin/pruned_transducer_statelessX/beam_search.py @@ -0,0 +1,133 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import List + +import torch +from torch.nn.utils.rnn import pad_sequence + +from sherpa import RnntConformerModel, greedy_search, modified_beam_search + +LOG_EPS = math.log(1e-10) + + +class GreedySearchOffline: + def __init__(self): + pass + + @torch.no_grad() + def process( + self, + model: "RnntConformerModel", + features: List[torch.Tensor], + ) -> List[List[int]]: + """ + Args: + model: + RNN-T model decoder model + + features: + A list of 2-D tensors. Each entry is of shape + (num_frames, feature_dim). 
+ Returns: + Return a list-of-list containing the decoding token IDs. + """ + features_length = torch.tensor( + [f.size(0) for f in features], + dtype=torch.int64, + ) + features = pad_sequence( + features, + batch_first=True, + padding_value=LOG_EPS, + ) + + device = model.device + features = features.to(device) + features_length = features_length.to(device) + + encoder_out, encoder_out_length = model.encoder( + features=features, + features_length=features_length, + ) + + hyp_tokens = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_length=encoder_out_length.cpu(), + ) + + return hyp_tokens + + +class ModifiedBeamSearchOffline: + def __init__(self, num_active_paths: int): + """ + Args: + num_active_paths: + Used only when decoding_method is modified_beam_search. + It specifies number of active paths for each utterance. Due to + merging paths with identical token sequences, the actual number + may be less than "num_active_paths". + """ + self.num_active_paths = num_active_paths + + @torch.no_grad() + def process( + self, + model: "RnntConformerModel", + features: List[torch.Tensor], + ) -> List[List[int]]: + """Run RNN-T model with the given features and use greedy search + to decode the output of the model. + + Args: + model: + The RNN-T model. + features: + A list of 2-D tensors. Each entry is of shape + (num_frames, feature_dim). + Returns: + Return a list-of-list containing the decoding token IDs. + """ + features_length = torch.tensor( + [f.size(0) for f in features], + dtype=torch.int64, + ) + features = pad_sequence( + features, + batch_first=True, + padding_value=LOG_EPS, + ) + + device = model.device + features = features.to(device) + features_length = features_length.to(device) + + encoder_out, encoder_out_length = model.encoder( + features=features, + features_length=features_length, + ) + + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_length=encoder_out_length.cpu(), + num_active_paths=self.num_active_paths, + ) + return hyp_tokens diff --git a/sherpa/bin/pruned_transducer_statelessX/decode.py b/sherpa/bin/pruned_transducer_statelessX/decode.py deleted file mode 100644 index 9c08bcc7..00000000 --- a/sherpa/bin/pruned_transducer_statelessX/decode.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import List - -import torch -from torch.nn.utils.rnn import pad_sequence - -from sherpa import RnntConformerModel, greedy_search, modified_beam_search - -LOG_EPS = math.log(1e-10) - - -@torch.no_grad() -def run_model_and_do_greedy_search( - model: RnntConformerModel, - features: List[torch.Tensor], -) -> List[List[int]]: - """Run RNN-T model with the given features and use greedy search - to decode the output of the model. - - Args: - model: - The RNN-T model. - features: - A list of 2-D tensors. Each entry is of shape - (num_frames, feature_dim). 
- Returns: - Return a list-of-list containing the decoding token IDs. - """ - features_length = torch.tensor( - [f.size(0) for f in features], - dtype=torch.int64, - ) - features = pad_sequence( - features, - batch_first=True, - padding_value=LOG_EPS, - ) - - device = model.device - features = features.to(device) - features_length = features_length.to(device) - - encoder_out, encoder_out_length = model.encoder( - features=features, - features_length=features_length, - ) - - hyp_tokens = greedy_search( - model=model, - encoder_out=encoder_out, - encoder_out_length=encoder_out_length.cpu(), - ) - return hyp_tokens - - -@torch.no_grad() -def run_model_and_do_modified_beam_search( - model: RnntConformerModel, - features: List[torch.Tensor], - num_active_paths: int, -) -> List[List[int]]: - """Run RNN-T model with the given features and use greedy search - to decode the output of the model. - - Args: - model: - The RNN-T model. - features: - A list of 2-D tensors. Each entry is of shape - (num_frames, feature_dim). - num_active_paths: - Used only when decoding_method is modified_beam_search. - It specifies number of active paths for each utterance. Due to - merging paths with identical token sequences, the actual number - may be less than "num_active_paths". - Returns: - Return a list-of-list containing the decoding token IDs. - """ - features_length = torch.tensor( - [f.size(0) for f in features], - dtype=torch.int64, - ) - features = pad_sequence( - features, - batch_first=True, - padding_value=LOG_EPS, - ) - - device = model.device - features = features.to(device) - features_length = features_length.to(device) - - encoder_out, encoder_out_length = model.encoder( - features=features, - features_length=features_length, - ) - - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_length=encoder_out_length.cpu(), - num_active_paths=num_active_paths, - ) - return hyp_tokens diff --git a/sherpa/bin/pruned_transducer_statelessX/offline_asr.py b/sherpa/bin/pruned_transducer_statelessX/offline_asr.py index c0303cad..348d4d79 100755 --- a/sherpa/bin/pruned_transducer_statelessX/offline_asr.py +++ b/sherpa/bin/pruned_transducer_statelessX/offline_asr.py @@ -85,7 +85,6 @@ $wav3 """ # noqa import argparse -import functools import logging from typing import List, Optional, Union @@ -94,10 +93,7 @@ import sentencepiece as spm import torch import torchaudio -from decode import ( - run_model_and_do_greedy_search, - run_model_and_do_modified_beam_search, -) +from beam_search import GreedySearchOffline, ModifiedBeamSearchOffline from sherpa import RnntConformerModel @@ -262,20 +258,16 @@ def __init__( "greedy_search", "modified_beam_search", ), decoding_method + if decoding_method == "greedy_search": - nn_and_decoding_func = run_model_and_do_greedy_search + self.beam_search = GreedySearchOffline() elif decoding_method == "modified_beam_search": - nn_and_decoding_func = functools.partial( - run_model_and_do_modified_beam_search, - num_active_paths=num_active_paths, - ) + self.beam_search = ModifiedBeamSearchOffline(num_active_paths) else: raise ValueError( - f"Unsupported decoding_method: {decoding_method} " - "Please use greedy_search or modified_beam_search" + f"Decoding method {decoding_method} is not supported." 
) - self.nn_and_decoding_func = nn_and_decoding_func self.device = device def _build_feature_extractor( @@ -325,7 +317,7 @@ def decode_waves(self, waves: List[torch.Tensor]) -> List[List[str]]: waves = [w.to(self.device) for w in waves] features = self.feature_extractor(waves) - tokens = self.nn_and_decoding_func(self.model, features) + tokens = self.beam_search.process(self.model, features) if hasattr(self, "sp"): results = self.sp.decode(tokens) @@ -429,7 +421,6 @@ def main(): if __name__ == "__main__": torch.manual_seed(20220609) - # fmt: off formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" # noqa # fmt: on diff --git a/sherpa/bin/pruned_transducer_statelessX/offline_server.py b/sherpa/bin/pruned_transducer_statelessX/offline_server.py index b4c69703..76d39e36 100755 --- a/sherpa/bin/pruned_transducer_statelessX/offline_server.py +++ b/sherpa/bin/pruned_transducer_statelessX/offline_server.py @@ -27,7 +27,6 @@ import argparse import asyncio -import functools import http import logging import warnings @@ -40,10 +39,7 @@ import sentencepiece as spm import torch import websockets -from decode import ( - run_model_and_do_greedy_search, - run_model_and_do_modified_beam_search, -) +from beam_search import GreedySearchOffline, ModifiedBeamSearchOffline from sherpa import RnntConformerModel @@ -287,25 +283,15 @@ def __init__( self.current_active_connections = 0 - assert decoding_method in ( - "greedy_search", - "modified_beam_search", - ), decoding_method if decoding_method == "greedy_search": - nn_and_decoding_func = run_model_and_do_greedy_search + self.beam_search = GreedySearchOffline() elif decoding_method == "modified_beam_search": - nn_and_decoding_func = functools.partial( - run_model_and_do_modified_beam_search, - num_active_paths=num_active_paths, - ) + self.beam_search = ModifiedBeamSearchOffline(num_active_paths) else: raise ValueError( - f"Unsupported decoding_method: {decoding_method} " - "Please use greedy_search or modified_beam_search" + f"Decoding method {decoding_method} is not supported." ) - self.nn_and_decoding_func = nn_and_decoding_func - self.decoding_method = decoding_method self.num_active_paths = num_active_paths def _build_feature_extractor(self) -> kaldifeat.OfflineFeature: @@ -484,7 +470,7 @@ async def feature_consumer_task(self): hyp_tokens = await loop.run_in_executor( self.nn_pool, - self.nn_and_decoding_func, + self.beam_search.process, model, feature_list, ) @@ -669,10 +655,9 @@ def main(): if __name__ == "__main__": torch.manual_seed(20220519) - - # fmt: off + # fmt:off formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" # noqa - # fmt: on + # fmt:on logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/sherpa/bin/streaming_pruned_transducer_statelessX/beam_search.py b/sherpa/bin/streaming_pruned_transducer_statelessX/beam_search.py new file mode 100644 index 00000000..d07b4511 --- /dev/null +++ b/sherpa/bin/streaming_pruned_transducer_statelessX/beam_search.py @@ -0,0 +1,299 @@ +from typing import List + +import k2 +import torch +from stream import Stream + +from sherpa import fast_beam_search_one_best, streaming_greedy_search + + +class FastBeamSearch: + def __init__( + self, + vocab_size: int, + context_size: int, + beam: int, + max_states: int, + max_contexts: int, + device: torch.device, + ): + """ + Args: + vocab_size: + Vocabularize of the BPE + context_size: + Context size of the RNN-T decoder model. + beam: + The beam for fast_beam_search decoding. 
+ max_states: + The max_states for fast_beam_search decoding. + max_contexts: + The max_contexts for fast_beam_search decoding. + device: + Device on which the computation will occur + """ + self.rnnt_decoding_config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + ) + self.decoding_graph = k2.trivial_graph(vocab_size - 1, device) + self.device = device + self.context_size = context_size + + def init_stream(self, stream: Stream): + """ + Attributes to add to each stream + """ + stream.rnnt_decoding_stream = k2.RnntDecodingStream(self.decoding_graph) + stream.hyp = [] + + @torch.no_grad() + def process( + self, + server: "StreamingServer", + stream_list: List[Stream], + ) -> None: + """Run the model on the given stream list and do search with fast_beam_search + method. + Args: + server: + An instance of `StreamingServer`. + stream_list: + A list of streams to be processed. It is changed in-place. + That is, the attribute `states` and `hyp` are + updated in-place. + """ + model = server.model + # Note: chunk_length is in frames before subsampling + chunk_length = server.chunk_length + subsampling_factor = server.subsampling_factor + # Note: chunk_size, left_context and right_context are in frames + # after subsampling + chunk_size = server.decode_chunk_size + left_context = server.decode_left_context + right_context = server.decode_right_context + + batch_size = len(stream_list) + + state_list = [] + feature_list = [] + processed_frames_list = [] + + rnnt_decoding_streams_list = [] + rnnt_decoding_config = self.rnnt_decoding_config + for s in stream_list: + rnnt_decoding_streams_list.append(s.rnnt_decoding_stream) + state_list.append(s.states) + processed_frames_list.append(s.processed_frames) + f = s.features[:chunk_length] + s.features = s.features[chunk_size * subsampling_factor :] + b = torch.cat(f, dim=0) + feature_list.append(b) + + features = torch.stack(feature_list, dim=0).to(self.device) + + states = [ + torch.stack([x[0] for x in state_list], dim=2), + torch.stack([x[1] for x in state_list], dim=2), + ] + + features_length = torch.full( + (batch_size,), + fill_value=features.size(1), + device=self.device, + dtype=torch.int64, + ) + + processed_frames = torch.tensor( + processed_frames_list, device=self.device + ) + + ( + encoder_out, + encoder_out_lens, + next_states, + ) = model.encoder_streaming_forward( + features=features, + features_length=features_length, + states=states, + processed_frames=processed_frames, + left_context=left_context, + right_context=right_context, + ) + + processed_lens = processed_frames + encoder_out_lens + next_hyp_list = fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + rnnt_decoding_config=rnnt_decoding_config, + rnnt_decoding_streams_list=rnnt_decoding_streams_list, + ) + + next_state_list = [ + torch.unbind(next_states[0], dim=2), + torch.unbind(next_states[1], dim=2), + ] + + for i, s in enumerate(stream_list): + s.states = [next_state_list[0][i], next_state_list[1][i]] + s.processed_frames += encoder_out_lens[i] + s.hyp = next_hyp_list[i] + + def get_texts(self, stream: Stream): + """ + Return text after decoding + Args: + stream: + Stream to be processed. 
+ """ + if hasattr(self, "sp"): + result = self.sp.decode(stream.hyp) + else: + result = [self.token_table[i] for i in stream.hyp] + return result + + +class GreedySearch: + def __init__(self, model: "RnntConformerModel", device: torch.device): + """ + Args: + model: + RNN-T model decoder model + device: + Device on which the computation will occur + """ + + self.blank_id = model.blank_id + self.context_size = model.context_size + self.device = device + + decoder_input = torch.tensor( + [[self.blank_id] * self.context_size], + device=self.device, + dtype=torch.int64, + ) + initial_decoder_out = model.decoder_forward(decoder_input) + self.initial_decoder_out = model.forward_decoder_proj( + initial_decoder_out.squeeze(1) + ) + + def init_stream(self, stream: Stream): + """ + Attributes to add to each stream + """ + stream.decoder_out = self.initial_decoder_out + stream.hyp = [self.blank_id] * self.context_size + + @torch.no_grad() + def process( + self, + server: "StreamingServer", + stream_list: List[Stream], + ) -> None: + """Run the model on the given stream list and do search with greedy_search + method. + Args: + server: + An instance of `StreamingServer`. + stream_list: + A list of streams to be processed. It is changed in-place. + That is, the attribute `states` and `hyp` are + updated in-place. + """ + model = server.model + device = model.device + # Note: chunk_length is in frames before subsampling + chunk_length = server.chunk_length + subsampling_factor = server.subsampling_factor + # Note: chunk_size, left_context and right_context are in frames + # after subsampling + chunk_size = server.decode_chunk_size + left_context = server.decode_left_context + right_context = server.decode_right_context + + batch_size = len(stream_list) + + state_list, feature_list, processed_frames_list = [], [], [] + decoder_out_list, hyp_list = [], [] + + for s in stream_list: + decoder_out_list.append(s.decoder_out) + hyp_list.append(s.hyp) + state_list.append(s.states) + processed_frames_list.append(s.processed_frames) + f = s.features[:chunk_length] + s.features = s.features[chunk_size * subsampling_factor :] + b = torch.cat(f, dim=0) + feature_list.append(b) + + features = torch.stack(feature_list, dim=0).to(device) + + states = [ + torch.stack([x[0] for x in state_list], dim=2), + torch.stack([x[1] for x in state_list], dim=2), + ] + + decoder_out = torch.cat(decoder_out_list, dim=0) + + features_length = torch.full( + (batch_size,), + fill_value=features.size(1), + device=device, + dtype=torch.int64, + ) + + processed_frames = torch.tensor(processed_frames_list, device=device) + + ( + encoder_out, + encoder_out_lens, + next_states, + ) = model.encoder_streaming_forward( + features=features, + features_length=features_length, + states=states, + processed_frames=processed_frames, + left_context=left_context, + right_context=right_context, + ) + + # Note: It does not return the next_encoder_out_len since + # there are no paddings for streaming ASR. Each stream + # has the same input number of frames, i.e., server.chunk_length. 
+ next_decoder_out, next_hyp_list = streaming_greedy_search( + model=model, + encoder_out=encoder_out, + decoder_out=decoder_out, + hyps=hyp_list, + ) + + next_state_list = [ + torch.unbind(next_states[0], dim=2), + torch.unbind(next_states[1], dim=2), + ] + next_decoder_out_list = next_decoder_out.split(1) + + for i, s in enumerate(stream_list): + s.states = [next_state_list[0][i], next_state_list[1][i]] + s.processed_frames += encoder_out_lens[i] + s.decoder_out = next_decoder_out_list[i] + s.hyp = next_hyp_list[i] + + def get_texts(self, stream: Stream): + """ + Return text after decoding + Args: + stream: + Stream to be processed. + """ + if hasattr(self, "sp"): + result = self.sp.decode(stream.hyp[self.context_size :]) # noqa + else: + result = [ + self.token_table[i] for i in stream.hyp[self.context_size :] + ] # noqa + return result diff --git a/sherpa/bin/streaming_pruned_transducer_statelessX/decode.py b/sherpa/bin/streaming_pruned_transducer_statelessX/stream.py similarity index 75% rename from sherpa/bin/streaming_pruned_transducer_statelessX/decode.py rename to sherpa/bin/streaming_pruned_transducer_statelessX/stream.py index 58dd8894..435e18f6 100644 --- a/sherpa/bin/streaming_pruned_transducer_statelessX/decode.py +++ b/sherpa/bin/streaming_pruned_transducer_statelessX/stream.py @@ -15,9 +15,8 @@ # limitations under the License. import math -from typing import List, Optional +from typing import List -import k2 import torch from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature @@ -45,33 +44,15 @@ class Stream(object): def __init__( self, context_size: int, - blank_id: int, initial_states: List[torch.Tensor], - decoding_method: str = "greedy_search", - decoding_graph: Optional[k2.Fsa] = None, - decoder_out: Optional[torch.Tensor] = None, ) -> None: """ Args: context_size: Context size of the RNN-T decoder model. - blank_id: - Blank token ID of the BPE model. initial_states: The initial states of the Conformer model. Note that the state does not contain the batch dimension. - decoding_method: - The decoding method to use, currently, only greedy_search and - fast_beam_search are supported. - decoding_graph: - The Fsa based decoding graph for fast_beam_search. Only used when - decoding_method is fast_beam_search, and it can not be None if - decoding_method is fast_beam_search. - decoder_out: - The initial decoder out corresponding to the decoder input - `[blank_id]*context_size`. Only used when decoding_method is - greedy_search, and it can not be None if decoding_method is - greedy_search. """ self.feature_extractor = _create_streaming_feature_extractor() # It contains a list of 2-D tensors representing the feature frames. @@ -80,21 +61,6 @@ def __init__( self.num_fetched_frames = 0 self.states = initial_states - self.decoding_graph = decoding_graph - if decoding_method == "fast_beam_search": - assert decoding_graph is not None - self.rnnt_decoding_stream = k2.RnntDecodingStream(decoding_graph) - self.hyp = [] - elif decoding_method == "greedy_search": - assert decoder_out is not None - self.decoder_out = decoder_out - self.hyp = [blank_id] * context_size - else: - # fmt: off - raise ValueError( - f"Decoding method {decoding_method} is not supported." - ) - # fmt: on # The number of frames (after subsampling) been processed. 
         self.processed_frames = 0
diff --git a/sherpa/bin/streaming_pruned_transducer_statelessX/streaming_server.py b/sherpa/bin/streaming_pruned_transducer_statelessX/streaming_server.py
index 471d4269..b1af14e5 100755
--- a/sherpa/bin/streaming_pruned_transducer_statelessX/streaming_server.py
+++ b/sherpa/bin/streaming_pruned_transducer_statelessX/streaming_server.py
@@ -35,20 +35,17 @@
 import math
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 
 import k2
 import numpy as np
 import sentencepiece as spm
 import torch
 import websockets
-from decode import Stream
+from beam_search import FastBeamSearch, GreedySearch
+from stream import Stream
 
-from sherpa import (
-    RnntConformerModel,
-    fast_beam_search_one_best,
-    streaming_greedy_search,
-)
+from sherpa import RnntConformerModel
 
 
 def get_args():
@@ -212,127 +209,6 @@ def get_args():
     return parser.parse_args()
 
 
-@torch.no_grad()
-def run_model_and_do_search(
-    server: "StreamingServer",
-    stream_list: List[Stream],
-) -> None:
-    """Run the model on the given stream list and do search with given decoding
-    method.
-    Args:
-      server:
-        An instance of `StreamingServer`.
-      stream_list:
-        A list of streams to be processed. It is changed in-place.
-        That is, the attribute `states`, `decoder_out`, and `hyp` are
-        updated in-place.
-    """
-    model = server.model
-    device = model.device
-    # Note: chunk_length is in frames before subsampling
-    chunk_length = server.chunk_length
-    subsampling_factor = server.subsampling_factor
-    # Note: chunk_size, left_context and right_context are in frames
-    # after subsampling
-    chunk_size = server.decode_chunk_size
-    left_context = server.decode_left_context
-    right_context = server.decode_right_context
-    decoding_method = server.decoding_method
-
-    batch_size = len(stream_list)
-
-    state_list = []
-    feature_list = []
-    processed_frames_list = []
-    if decoding_method == "greedy_search":
-        decoder_out_list = []
-        hyp_list = []
-    else:
-        rnnt_decoding_streams_list = []
-        rnnt_decoding_config = server.rnnt_decoding_config
-    for s in stream_list:
-        if decoding_method == "greedy_search":
-            decoder_out_list.append(s.decoder_out)
-            hyp_list.append(s.hyp)
-        if decoding_method == "fast_beam_search":
-            rnnt_decoding_streams_list.append(s.rnnt_decoding_stream)
-        state_list.append(s.states)
-        processed_frames_list.append(s.processed_frames)
-        f = s.features[:chunk_length]
-        s.features = s.features[chunk_size * subsampling_factor :]  # noqa
-        b = torch.cat(f, dim=0)
-        feature_list.append(b)
-
-    features = torch.stack(feature_list, dim=0).to(device)
-
-    states = [
-        torch.stack([x[0] for x in state_list], dim=2),
-        torch.stack([x[1] for x in state_list], dim=2),
-    ]
-
-    if decoding_method == "greedy_search":
-        decoder_out = torch.cat(decoder_out_list, dim=0)
-
-    features_length = torch.full(
-        (batch_size,),
-        fill_value=features.size(1),
-        device=device,
-        dtype=torch.int64,
-    )
-
-    processed_frames = torch.tensor(processed_frames_list, device=device)
-
-    # fmt: off
-    (
-        encoder_out,
-        encoder_out_lens,
-        next_states,
-    ) = model.encoder_streaming_forward(
-        features=features,
-        features_length=features_length,
-        states=states,
-        processed_frames=processed_frames,
-        left_context=left_context,
-        right_context=right_context,
-    )
-    # fmt: on
-
-    if decoding_method == "fast_beam_search":
-        processed_lens = processed_frames + encoder_out_lens
-        next_hyp_list = fast_beam_search_one_best(
-            model=model,
-            encoder_out=encoder_out,
-            processed_lens=processed_lens,
-            rnnt_decoding_config=rnnt_decoding_config,
-            rnnt_decoding_streams_list=rnnt_decoding_streams_list,
-        )
-    elif decoding_method == "greedy_search":
-        # Note: It does not return the next_encoder_out_len since
-        # there are no paddings for streaming ASR. Each stream
-        # has the same input number of frames, i.e., server.chunk_length.
-        next_decoder_out, next_hyp_list = streaming_greedy_search(
-            model=model,
-            encoder_out=encoder_out,
-            decoder_out=decoder_out,
-            hyps=hyp_list,
-        )
-    else:
-        raise ValueError(f"Decoding method {decoding_method} is not supported.")
-
-    next_state_list = [
-        torch.unbind(next_states[0], dim=2),
-        torch.unbind(next_states[1], dim=2),
-    ]
-    if decoding_method == "greedy_search":
-        next_decoder_out_list = next_decoder_out.split(1)
-    for i, s in enumerate(stream_list):
-        s.states = [next_state_list[0][i], next_state_list[1][i]]
-        s.processed_frames += encoder_out_lens[i]
-        if decoding_method == "greedy_search":
-            s.decoder_out = next_decoder_out_list[i]
-        s.hyp = next_hyp_list[i]
-
-
 class StreamingServer(object):
     def __init__(
         self,
@@ -376,8 +252,8 @@ def __init__(
           max_contexts:
             The max_contexts for fast_beam_search decoding.
           decoding_method:
-            The decoding method to use, can be either greedy_search or
-            fast_beam_search.
+            The decoding method to use, can be either greedy_search
+            or fast_beam_search.
           nn_pool_size:
             Number of threads for the thread pool that is responsible for
             neural network computation and decoding.
@@ -432,36 +308,34 @@ def __init__(
         self.initial_states = self.model.get_encoder_init_states(
             self.decode_left_context
         )
 
-        self.decoding_method = decoding_method
-        self.initial_decoder_out = None
-        self.decoding_graph = None
         if decoding_method == "fast_beam_search":
-            self.rnnt_decoding_config = k2.RnntDecodingConfig(
+            self.beam_search = FastBeamSearch(
                 vocab_size=self.vocab_size,
-                decoder_history_len=self.context_size,
+                context_size=self.context_size,
                 beam=beam,
                 max_states=max_states,
                 max_contexts=max_contexts,
-            )
-            self.decoding_graph = k2.trivial_graph(self.vocab_size - 1, device)
-        elif decoding_method == "greedy_search":
-            decoder_input = torch.tensor(
-                [[self.blank_id] * self.context_size],
                 device=device,
-                dtype=torch.int64,
             )
-            initial_decoder_out = self.model.decoder_forward(decoder_input)
-            self.initial_decoder_out = self.model.forward_decoder_proj(
-                initial_decoder_out.squeeze(1)
+        elif decoding_method == "greedy_search":
+            self.beam_search = GreedySearch(
+                self.model,
+                device,
             )
         else:
-            # fmt: off
            raise ValueError(
                f"Decoding method {decoding_method} is not supported."
            )
-            # fmt: on
+
+        if bpe_model_filename:
+            self.beam_search.sp = spm.SentencePieceProcessor()
+            self.beam_search.sp.load(bpe_model_filename)
+        else:
+            self.beam_search.token_table = k2.SymbolTable.from_file(
+                token_filename
+            )
 
         self.nn_pool = ThreadPoolExecutor(
             max_workers=nn_pool_size,
@@ -504,7 +378,7 @@ async def stream_consumer_task(self):
             loop = asyncio.get_running_loop()
             await loop.run_in_executor(
                 self.nn_pool,
-                run_model_and_do_search,
+                self.beam_search.process,
                 self,
                 stream_list,
             )
@@ -565,7 +439,7 @@ async def handle_connection(
         socket: websockets.WebSocketServerProtocol,
     ):
        """Receive audio samples from the client, process it, and send
-        deocoding result back to the client.
+        decoding result back to the client.
 
         Args:
           socket:
@@ -599,13 +473,11 @@ async def handle_connection_impl(
         )
         stream = Stream(
             context_size=self.context_size,
-            blank_id=self.blank_id,
             initial_states=self.initial_states,
-            decoding_method=self.decoding_method,
-            decoding_graph=self.decoding_graph,
-            decoder_out=self.initial_decoder_out,
         )
 
+        self.beam_search.init_stream(stream)
+
         while True:
             samples = await self.recv_audio_samples(socket)
             if samples is None:
@@ -617,31 +489,8 @@ async def handle_connection_impl(
 
             while len(stream.features) > self.chunk_length:
                 await self.compute_and_decode(stream)
-            if self.decoding_method == "greedy_search":
-                if hasattr(self, "sp"):
-                    # fmt: off
-                    result = self.sp.decode(
-                        stream.hyp[self.context_size :]  # noqa
-                    )
-                    # fmt: on
-                else:
-                    # fmt: off
-                    result = [
-                        self.token_table[i]
-                        for i in stream.hyp[self.context_size :]  # noqa
-                    ]
-                    # fmt: on
-                await socket.send(result)
-            elif self.decoding_method == "fast_beam_search":
-                if hasattr(self, "sp"):
-                    result = self.sp.decode(stream.hyp)
-                else:
-                    result = [self.token_table[i] for i in stream.hyp]
-                await socket.send(result)
-            else:
-                raise ValueError(
-                    f"Decoding method {self.decoding_method} is not supported."  # noqa
-                )
+            result = self.beam_search.get_texts(stream)
+            await socket.send(result)
 
         stream.input_finished()
         while len(stream.features) > self.chunk_length:
@@ -653,29 +502,9 @@ async def handle_connection_impl(
             await self.compute_and_decode(stream)
         stream.features = []
 
-        if self.decoding_method == "greedy_search":
-            if hasattr(self, "sp"):
-                result = self.sp.decode(stream.hyp[self.context_size :])  # noqa
-            else:
-                # fmt: off
-                result = [
-                    self.token_table[i]
-                    for i in stream.hyp[self.context_size :]  # noqa
-                ]
-                # fmt: on
-            await socket.send(result)
-            await socket.send("Done")
-        elif self.decoding_method == "fast_beam_search":
-            if hasattr(self, "sp"):
-                result = self.sp.decode(stream.hyp)
-            else:
-                result = [self.token_table[i] for i in stream.hyp]
-            await socket.send(result)
-            await socket.send("Done")
-        else:
-            raise ValueError(
-                f"Decoding method {self.decoding_method} is not supported."
-            )
+        result = self.beam_search.get_texts(stream)
+        await socket.send(result)
+        await socket.send("Done")
 
     async def recv_audio_samples(
         self,
@@ -775,8 +604,8 @@ def main():
 """
 
 if __name__ == "__main__":
-    # fmt: off
+    # fmt:off
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"  # noqa
-    # fmt: on
+    # fmt:on
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
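
Note (not part of the patch above): after this refactoring, StreamingServer only talks to self.beam_search through three duck-typed methods, init_stream(), process() and get_texts(), and attaches either sp (a sentencepiece processor) or token_table (a k2.SymbolTable) to it after construction. The sketch below illustrates that contract with a hypothetical ModifiedBeamSearch class; the class name, its beam_size argument and the method bodies are illustrative assumptions, not code from sherpa.

# Illustration only -- a hypothetical decoding method plugged into the
# interface that StreamingServer expects from self.beam_search after this
# patch.  ModifiedBeamSearch and beam_size are made-up names; only the three
# method names and the sp/token_table attributes mirror the real
# FastBeamSearch and GreedySearch classes.
from typing import List

import torch


class ModifiedBeamSearch:
    def __init__(self, beam_size: int):
        self.beam_size = beam_size

    def init_stream(self, stream) -> None:
        # Attach whatever per-stream decoding state this method needs,
        # mirroring FastBeamSearch.init_stream / GreedySearch.init_stream.
        stream.hyp = []

    @torch.no_grad()
    def process(self, server, stream_list: List) -> None:
        # Batch the streams, call server.model.encoder_streaming_forward and
        # update stream.states / stream.hyp in place, as the two real classes
        # in beam_search.py do.
        raise NotImplementedError

    def get_texts(self, stream):
        # StreamingServer attaches either sp (sentencepiece) or token_table
        # (k2.SymbolTable) after construction, exactly as for the real classes.
        if hasattr(self, "sp"):
            return self.sp.decode(stream.hyp)
        return [self.token_table[i] for i in stream.hyp]

With such a class in place, StreamingServer.__init__ would only need one more elif branch to select it, while stream_consumer_task and handle_connection_impl would keep working unchanged, since they call nothing beyond self.beam_search.process and self.beam_search.get_texts.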