Add fast_beam_search_nbest to streaming and offline models
ezerhouni committed Jul 25, 2022
1 parent 6d5abe4 commit d902087
Showing 15 changed files with 340 additions and 335 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-streaming-conformer-test.yaml
@@ -34,7 +34,7 @@ jobs:
torch: ["1.10.0", "1.6.0"]
torchaudio: ["0.10.0", "0.6.0"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "fast_beam_search"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest"]
exclude:
- torch: "1.10.0"
torchaudio: "0.6.0"
2 changes: 1 addition & 1 deletion .github/workflows/run-streaming-conv-emformer-test.yaml
@@ -34,7 +34,7 @@ jobs:
torch: ["1.10.0", "1.6.0"]
torchaudio: ["0.10.0", "0.6.0"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "fast_beam_search"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest"]
exclude:
- torch: "1.10.0"
torchaudio: "0.6.0"
2 changes: 1 addition & 1 deletion .github/workflows/run-streaming-test-windows-cpu.yaml
@@ -34,7 +34,7 @@ jobs:
torch: ["1.10.0"]
torchaudio: ["0.10.0"]
python-version: ["3.7", "3.8", "3.9"]
decoding: ["greedy_search", "modified_beam_search", "fast_beam_search"]
decoding: ["greedy_search", "modified_beam_search", "fast_beam_search", "fast_beam_search_nbest"]
steps:
- uses: actions/checkout@v2
with:
2 changes: 1 addition & 1 deletion .github/workflows/run-streaming-test.yaml
@@ -34,7 +34,7 @@ jobs:
torch: ["1.10.0", "1.6.0"]
torchaudio: ["0.10.0", "0.6.0"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "modified_beam_search", "fast_beam_search"]
decoding: ["greedy_search", "modified_beam_search", "fast_beam_search", "fast_beam_search_nbest"]
exclude:
- torch: "1.10.0"
torchaudio: "0.6.0"
@@ -34,7 +34,7 @@ jobs:
torch: ["1.11.0", "1.7.1"]
torchaudio: ["0.11.0", "0.7.2"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "fast_beam_search"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest"]
exclude:
- torch: "1.11.0"
torchaudio: "0.7.2"
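
The workflow changes above only add fast_beam_search_nbest to each CI test matrix; this diff does not show how the matrix value reaches the test scripts. As a rough sketch under that caveat (the DECODING_METHOD environment variable and the helper below are assumptions, not part of this commit), a test driver could validate the selected method before starting a run:

import os

# Hypothetical glue code: the commit only extends the CI matrices, so the
# variable name used to hand the value to the test script is assumed here.
SUPPORTED_METHODS = {
    "greedy_search",
    "modified_beam_search",
    "fast_beam_search",
    "fast_beam_search_nbest",
}

def get_decoding_method() -> str:
    method = os.environ.get("DECODING_METHOD", "greedy_search")
    if method not in SUPPORTED_METHODS:
        raise ValueError(f"Unsupported decoding method: {method}")
    return method
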
108 changes: 71 additions & 37 deletions sherpa/bin/conv_emformer_transducer_stateless2/beam_search.py
@@ -4,44 +4,47 @@
import torch
from stream import Stream, stack_states, unstack_states

from sherpa import fast_beam_search_one_best, streaming_greedy_search
from sherpa import (
VALID_FAST_BEAM_SEARCH_METHOD,
fast_beam_search_nbest,
fast_beam_search_one_best,
streaming_greedy_search,
)


class FastBeamSearch:
def __init__(
self,
vocab_size: int,
context_size: int,
beam: int,
max_states: int,
max_contexts: int,
beam_search_params: dict,
device: torch.device,
):
"""
Args:
vocab_size:
Vocabulary size of the BPE model
context_size:
Context size of the RNN-T decoder model.
beam:
The beam for fast_beam_search decoding.
max_states:
The max_states for fast_beam_search decoding.
max_contexts:
The max_contexts for fast_beam_search decoding.
beam_search_params:
Dictionary containing all the parameters for beam search.
device:
Device on which the computation will occur
"""

decoding_method = beam_search_params["decoding_method"]
assert (
decoding_method in VALID_FAST_BEAM_SEARCH_METHOD
), f"{decoding_method} is not a valid search method"

self.decoding_method = decoding_method
self.rnnt_decoding_config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
vocab_size=beam_search_params["vocab_size"],
decoder_history_len=beam_search_params["context_size"],
beam=beam_search_params["beam"],
max_states=beam_search_params["max_states"],
max_contexts=beam_search_params["max_contexts"],
)
self.decoding_graph = k2.trivial_graph(
beam_search_params["vocab_size"] - 1, device
)
self.decoding_graph = k2.trivial_graph(vocab_size - 1, device)
self.device = device
self.context_size = context_size
self.context_size = beam_search_params["context_size"]
self.beam_search_params = beam_search_params

def init_stream(self, stream: Stream):
"""
@@ -114,13 +117,30 @@ def process(
)

processed_lens = (num_processed_frames >> 2) + encoder_out_lens
next_hyp_list = fast_beam_search_one_best(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
rnnt_decoding_config=rnnt_decoding_config,
rnnt_decoding_streams_list=rnnt_decoding_streams_list,
)
if self.decoding_method == "fast_beam_search_nbest":
next_hyp_list = fast_beam_search_nbest(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
rnnt_decoding_config=rnnt_decoding_config,
rnnt_decoding_streams_list=rnnt_decoding_streams_list,
num_paths=self.beam_search_params["num_paths"],
nbest_scale=self.beam_search_params["nbest_scale"],
use_double_scores=True,
temperature=self.beam_search_params["temperature"],
)
elif self.decoding_method == "fast_beam_search":
next_hyp_list = fast_beam_search_one_best(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
rnnt_decoding_config=rnnt_decoding_config,
rnnt_decoding_streams_list=rnnt_decoding_streams_list,
)
else:
raise NotImplementedError(
f"{self.decoding_method} is not implemented"
)

next_state_list = unstack_states(next_states)
for i, s in enumerate(stream_list):
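
process now branches on the configured method: fast_beam_search keeps the single best lattice path, while fast_beam_search_nbest, roughly, samples num_paths paths from the decoding lattice (with lattice scores scaled by nbest_scale before sampling) and keeps the best-scoring path. The snippet below is a deliberately simplified illustration of that n-best idea in plain Python, not the k2-based routine that fast_beam_search_nbest actually uses:

from typing import Callable, List, Sequence

def nbest_pick(
    sample_path: Callable[[], List[int]],          # draws one token path from the lattice
    score_path: Callable[[Sequence[int]], float],  # total score of a path
    num_paths: int,
) -> List[int]:
    # Simplified n-best selection: sample candidates, keep the best-scoring one.
    candidates = [sample_path() for _ in range(num_paths)]
    return max(candidates, key=score_path)
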
@@ -138,24 +158,34 @@ def get_texts(self, stream: Stream):


class GreedySearch:
def __init__(self, model: "RnntConvEmformerModel", device: torch.device):
def __init__(
self,
model: "RnntConvEmformerModel",
beam_search_params: dict,
device: torch.device,
):
"""
Args:
model:
The RNN-T model.
beam_search_params:
Dictionary containing all the parameters for beam search.
device:
Device on which the computation will occur
"""

self.blank_id = model.blank_id
self.context_size = model.context_size
self.device = device
self.beam_search_params = beam_search_params
self.device = device

decoder_input = torch.tensor(
[[self.blank_id] * self.context_size],
[
[self.beam_search_params["blank_id"]]
* self.beam_search_params["context_size"]
],
device=self.device,
dtype=torch.int64,
)

initial_decoder_out = model.decoder_forward(decoder_input)
self.initial_decoder_out = model.forward_decoder_proj(
initial_decoder_out.squeeze(1)
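
Both search classes now take the same beam_search_params dictionary, so a caller can choose the implementation from its decoding_method entry alone. A hedged sketch of such a dispatch (the server code that actually performs this selection is not part of the excerpt shown here):

from beam_search import FastBeamSearch, GreedySearch  # the module modified above

def build_beam_search(model, beam_search_params: dict, device):
    # Select the search class named by beam_search_params["decoding_method"].
    method = beam_search_params["decoding_method"]
    if method == "greedy_search":
        return GreedySearch(model, beam_search_params, device)
    if method in ("fast_beam_search", "fast_beam_search_nbest"):
        return FastBeamSearch(beam_search_params, device)
    raise ValueError(f"Unknown decoding method: {method}")
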
Expand All @@ -166,7 +196,9 @@ def init_stream(self, stream: Stream):
Attributes to add to each stream
"""
stream.decoder_out = self.initial_decoder_out
stream.hyp = [self.blank_id] * self.context_size
stream.hyp = [
self.beam_search_params["blank_id"]
] * self.beam_search_params["context_size"]

@torch.no_grad()
def process(
@@ -257,4 +289,6 @@ def get_texts(self, stream: Stream):
stream:
Stream to be processed.
"""
return self.sp.decode(stream.hyp[self.context_size :])
return self.sp.decode(
stream.hyp[self.beam_search_params["context_size"] :]
)
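
get_texts drops the first context_size entries of stream.hyp before decoding because init_stream seeds each hypothesis with that many blank tokens. A tiny self-contained illustration of that bookkeeping (plain integers standing in for the SentencePiece model):

context_size = 2   # stands in for beam_search_params["context_size"]
blank_id = 0       # stands in for beam_search_params["blank_id"]

# init_stream seeds the hypothesis with context_size blanks ...
hyp = [blank_id] * context_size

# ... greedy decoding appends real token ids frame by frame ...
hyp.extend([17, 42, 7])

# ... and get_texts strips the seed tokens before mapping ids to text.
assert hyp[context_size:] == [17, 42, 7]
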