
Merge pull request #83 from ezerhouni/add-fast-beam-search-nbest-lg
Add fast_beam_search_nbest_LG to streaming models
ezerhouni authored Jul 28, 2022
2 parents adba776 + f8e395c commit eef8987
Showing 15 changed files with 542 additions and 25 deletions.
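In short: this commit wires a fourth decoding method, fast_beam_search_nbest_LG, into three streaming model servers and their CI workflows. Unlike the existing fast-beam-search variants, it decodes against an LG graph (the lexicon L composed with an n-gram language model G) loaded from a new --lang-dir directory, and its hypotheses are word IDs rather than BPE token IDs. A hedged sketch of the beam_search_params keys the new code path reads; only the key names come from this diff, the values are illustrative guesses:

from pathlib import Path

beam_search_params = {
    "lang_dir": Path("./data/lang_bpe_500"),  # new: must contain LG.pt
    "ngram_lm_scale": 0.01,  # new: scales the LG graph's LM scores
    "num_paths": 200,        # n-best paths drawn from the decoding lattice
    "nbest_scale": 0.5,      # score scaling applied before sampling paths
    "temperature": 1.0,
    "vocab_size": 500,       # lang_bpe_500
    "context_size": 2,
    "beam": 4,
    "max_states": 32,
    "max_contexts": 8,
}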
3 changes: 2 additions & 1 deletion .github/workflows/run-streaming-conformer-test.yaml
@@ -34,7 +34,7 @@ jobs:
torch: ["1.10.0", "1.6.0"]
torchaudio: ["0.10.0", "0.6.0"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest", "fast_beam_search_nbest_LG"]
exclude:
- torch: "1.10.0"
torchaudio: "0.6.0"
@@ -127,6 +127,7 @@ jobs:
--nn-pool-size 1 \
--nn-model-filename ./icefall-asr-librispeech-pruned-stateless-streaming-conformer-rnnt4-2022-06-10/exp/cpu_jit-epoch-29-avg-6_torch-${{ matrix.torch }}.pt \
--decoding-method ${{ matrix.decoding }} \
+--lang-dir ./icefall-asr-librispeech-pruned-stateless-streaming-conformer-rnnt4-2022-06-10/data/lang_bpe_500/ \
--bpe-model-filename ./icefall-asr-librispeech-pruned-stateless-streaming-conformer-rnnt4-2022-06-10/data/lang_bpe_500/bpe.model &
echo "Sleep 10 seconds to wait for the server startup"
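Each workflow below gets the same two edits as this one: fast_beam_search_nbest_LG joins the decoding matrix, and the server launch gains a --lang-dir flag pointing at the directory that holds LG.pt and the word table. A hypothetical sketch of how such a flag can be declared; the flag name and its Path-like use come from this diff (beam_search_params["lang_dir"] / "LG.pt" further down), but the parser wiring here is illustrative, not the repository's actual code:

import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument(
    "--lang-dir",
    type=Path,
    help="Directory containing LG.pt and the word table; "
    "required by fast_beam_search_nbest_LG",
)
args = parser.parse_args(["--lang-dir", "./data/lang_bpe_500/"])
print(args.lang_dir / "LG.pt")  # data/lang_bpe_500/LG.pt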
3 changes: 2 additions & 1 deletion .github/workflows/run-streaming-conv-emformer-test.yaml
@@ -34,7 +34,7 @@ jobs:
torch: ["1.10.0", "1.6.0"]
torchaudio: ["0.10.0", "0.6.0"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest", "fast_beam_search_nbest_LG"]
exclude:
- torch: "1.10.0"
torchaudio: "0.6.0"
@@ -127,6 +127,7 @@ jobs:
--nn-pool-size 1 \
--nn-model-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/cpu-jit-epoch-30-avg-10-torch-${{ matrix.torch }}.pt \
--decoding-method ${{ matrix.decoding }} \
+--lang-dir ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/data/lang_bpe_500/ \
--bpe-model-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/data/lang_bpe_500/bpe.model &
echo "Sleep 10 seconds to wait for the server startup"
3 changes: 2 additions & 1 deletion .github/workflows/run-streaming-test-windows-cpu.yaml
@@ -34,7 +34,7 @@ jobs:
torch: ["1.10.0"]
torchaudio: ["0.10.0"]
python-version: ["3.7", "3.8", "3.9"]
decoding: ["greedy_search", "modified_beam_search", "fast_beam_search", "fast_beam_search_nbest"]
decoding: ["greedy_search", "modified_beam_search", "fast_beam_search", "fast_beam_search_nbest", "fast_beam_search_nbest_LG"]
steps:
- uses: actions/checkout@v2
with:
@@ -93,6 +93,7 @@ jobs:
--max-wait-ms 5 \
--nn-pool-size 1 \
--nn-model-filename ./icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01/exp/cpu_jit-epoch-39-avg-6-use-averaged-model-1.pt \
+--lang-dir ./icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01/data/lang_bpe_500/ \
--bpe-model-filename ./icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01/data/lang_bpe_500/bpe.model &
echo "Sleep 10 seconds to wait for the server startup"
1 change: 1 addition & 0 deletions .github/workflows/run-streaming-test-with-long-waves.yaml
@@ -135,6 +135,7 @@ jobs:
--max-wait-ms 5 \
--nn-pool-size 1 \
--nn-model-filename ./icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01/exp/cpu_jit-epoch-39-avg-6-use-averaged-model-1-torch-${{ matrix.torch }}.pt \
+--lang-dir ./icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01/data/lang_bpe_500/ \
--bpe-model-filename ./icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01/data/lang_bpe_500/bpe.model &
echo "Sleep 10 seconds to wait for the server startup"
3 changes: 2 additions & 1 deletion .github/workflows/run-streaming-test.yaml
@@ -34,7 +34,7 @@ jobs:
torch: ["1.10.0", "1.6.0"]
torchaudio: ["0.10.0", "0.6.0"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "modified_beam_search", "fast_beam_search", "fast_beam_search_nbest"]
decoding: ["greedy_search", "modified_beam_search", "fast_beam_search", "fast_beam_search_nbest", "fast_beam_search_nbest_LG"]
exclude:
- torch: "1.10.0"
torchaudio: "0.6.0"
@@ -128,6 +128,7 @@ jobs:
--max-wait-ms 5 \
--nn-pool-size 1 \
--nn-model-filename ./icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01/exp/cpu_jit-epoch-39-avg-6-use-averaged-model-1-torch-${{ matrix.torch }}.pt \
+--lang-dir ./icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01/data/lang_bpe_500/ \
--bpe-model-filename ./icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01/data/lang_bpe_500/bpe.model &
echo "Sleep 10 seconds to wait for the server startup"
1 change: 1 addition & 0 deletions .github/workflows/run-test-windows-cpu.yaml
@@ -95,6 +95,7 @@ jobs:
--feature-extractor-pool-size 5 \
--nn-pool-size 1 \
--nn-model-filename ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/cpu_jit.pt \
+--lang-dir ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/ \
--bpe-model-filename ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/bpe.model &
echo "Sleep 10 seconds to wait for the server startup"
sleep 10
1 change: 1 addition & 0 deletions .github/workflows/run-test.yaml
@@ -131,6 +131,7 @@ jobs:
--feature-extractor-pool-size 5 \
--nn-pool-size 1 \
--nn-model-filename ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/cpu_jit-torch-${{ matrix.torch }}.pt \
+--lang-dir ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/ \
--bpe-model-filename ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/bpe.model &
echo "Sleep 10 seconds to wait for the server startup"
sleep 10
@@ -34,7 +34,7 @@ jobs:
torch: ["1.11.0", "1.7.1"]
torchaudio: ["0.11.0", "0.7.2"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest", "fast_beam_search_nbest_LG"]
exclude:
- torch: "1.11.0"
torchaudio: "0.7.2"
@@ -118,6 +118,7 @@ jobs:
--nn-pool-size 1 \
--nn-model-filename ./icefall_asr_wenetspeech_pruned_transducer_stateless5_streaming/exp/cpu_jit_epoch_5_avg_1_torch.${{ matrix.torch }}.pt \
--decoding-method ${{ matrix.decoding }} \
+--lang-dir ./icefall_asr_wenetspeech_pruned_transducer_stateless5_streaming/data/lang_char/ \
--token-filename ./icefall_asr_wenetspeech_pruned_transducer_stateless5_streaming/data/lang_char/tokens.txt &
echo "Sleep 10 seconds to wait for the server startup"
sleep 10
41 changes: 35 additions & 6 deletions sherpa/bin/conv_emformer_transducer_stateless2/beam_search.py
@@ -6,7 +6,9 @@

from sherpa import (
    VALID_FAST_BEAM_SEARCH_METHOD,
+    Lexicon,
    fast_beam_search_nbest,
+    fast_beam_search_nbest_LG,
    fast_beam_search_one_best,
    streaming_greedy_search,
)
@@ -39,9 +41,18 @@ def __init__(
            max_states=beam_search_params["max_states"],
            max_contexts=beam_search_params["max_contexts"],
        )
-        self.decoding_graph = k2.trivial_graph(
-            beam_search_params["vocab_size"] - 1, device
-        )
+        if decoding_method == "fast_beam_search_nbest_LG":
+            lexicon = Lexicon(beam_search_params["lang_dir"])
+            self.word_table = lexicon.word_table
+            lg_filename = beam_search_params["lang_dir"] / "LG.pt"
+            self.decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            self.decoding_graph.scores *= beam_search_params["ngram_lm_scale"]
+        else:
+            self.decoding_graph = k2.trivial_graph(
+                beam_search_params["vocab_size"] - 1, device
+            )
        self.device = device
        self.context_size = beam_search_params["context_size"]
        self.beam_search_params = beam_search_params
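For orientation (a sketch, not part of the diff): the non-LG methods search against a trivial graph, an FSA that accepts every token sequence, so the transducer's own scores alone guide the search; fast_beam_search_nbest_LG instead searches through an LG graph, which constrains paths to the lexicon and adds pre-scaled n-gram LM scores. The same if/else recurs in the other two beam_search.py files below. A standalone version with illustrative values:

from pathlib import Path

import k2
import torch

device = torch.device("cpu")

# Acoustic-only search: accepts any sequence over token IDs 0..vocab_size-1.
trivial = k2.trivial_graph(500 - 1, device)  # 500 = BPE vocab size here

# Lexicon+LM search: LG.pt is a precomputed FSA whose arc scores carry the
# n-gram LM weights, scaled down before decoding (0.01 is a guess).
lg = k2.Fsa.from_dict(
    torch.load(Path("data/lang_bpe_500") / "LG.pt", map_location=device)
)
lg.scores *= 0.01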
@@ -130,6 +141,18 @@ def process(
                use_double_scores=True,
                temperature=self.beam_search_params["temperature"],
            )
+        elif self.decoding_method == "fast_beam_search_nbest_LG":
+            next_hyp_list = fast_beam_search_nbest_LG(
+                model=model,
+                encoder_out=encoder_out,
+                processed_lens=processed_lens,
+                rnnt_decoding_config=rnnt_decoding_config,
+                rnnt_decoding_streams_list=rnnt_decoding_streams_list,
+                num_paths=self.beam_search_params["num_paths"],
+                nbest_scale=self.beam_search_params["nbest_scale"],
+                use_double_scores=True,
+                temperature=self.beam_search_params["temperature"],
+            )
        elif self.decoding_method == "fast_beam_search":
            next_hyp_list = fast_beam_search_one_best(
                model=model,
@@ -148,14 +171,20 @@ def process(
            s.states = next_state_list[i]
            s.hyp = next_hyp_list[i]

-    def get_texts(self, stream: Stream):
+    def get_texts(self, stream: Stream) -> str:
        """
        Return text after decoding
        Args:
          stream:
            Stream to be processed.
        """
-        return self.sp.decode(stream.hyp)
+        if self.decoding_method == "fast_beam_search_nbest_LG":
+            result = [self.word_table[i] for i in stream.hyp]
+            result = " ".join(result)
+        else:
+            result = self.sp.decode(stream.hyp)
+
+        return result


class GreedySearch:
@@ -284,7 +313,7 @@ def process(
            s.decoder_out = next_decoder_out_list[i]
            s.hyp = next_hyp_list[i]

-    def get_texts(self, stream: Stream):
+    def get_texts(self, stream: Stream) -> str:
        """
        Return text after decoding
        Args:
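The get_texts change above is the user-visible half of the feature: LG decoding yields word IDs, which are looked up in the lexicon's word table and joined with spaces, whereas the other methods yield BPE token IDs that SentencePiece decodes. A self-contained sketch of the two paths; the table here is a stand-in, in sherpa it comes from Lexicon(lang_dir).word_table:

# fast_beam_search_nbest_LG: the hypothesis is a list of word IDs.
word_table = {15: "HELLO", 42: "WORLD"}  # stand-in for the real word table
hyp = [15, 42]
text = " ".join(word_table[i] for i in hyp)
assert text == "HELLO WORLD"

# Other methods: the hypothesis is a list of BPE token IDs, and the real
# code calls self.sp.decode(stream.hyp) on a loaded SentencePiece model.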
43 changes: 36 additions & 7 deletions sherpa/bin/pruned_stateless_emformer_rnnt2/beam_search.py
@@ -8,7 +8,9 @@
    VALID_FAST_BEAM_SEARCH_METHOD,
    Hypotheses,
    Hypothesis,
+    Lexicon,
    fast_beam_search_nbest,
+    fast_beam_search_nbest_LG,
    fast_beam_search_one_best,
    streaming_greedy_search,
    streaming_modified_beam_search,
@@ -42,9 +44,18 @@ def __init__(
            max_states=beam_search_params["max_states"],
            max_contexts=beam_search_params["max_contexts"],
        )
-        self.decoding_graph = k2.trivial_graph(
-            beam_search_params["vocab_size"] - 1, device
-        )
+        if decoding_method == "fast_beam_search_nbest_LG":
+            lexicon = Lexicon(beam_search_params["lang_dir"])
+            self.word_table = lexicon.word_table
+            lg_filename = beam_search_params["lang_dir"] / "LG.pt"
+            self.decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            self.decoding_graph.scores *= beam_search_params["ngram_lm_scale"]
+        else:
+            self.decoding_graph = k2.trivial_graph(
+                beam_search_params["vocab_size"] - 1, device
+            )
        self.device = device
        self.context_size = beam_search_params["context_size"]
        self.beam_search_params = beam_search_params
@@ -128,6 +139,18 @@ def process(
                use_double_scores=True,
                temperature=self.beam_search_params["temperature"],
            )
+        elif self.decoding_method == "fast_beam_search_nbest_LG":
+            next_hyp_list = fast_beam_search_nbest_LG(
+                model=model,
+                encoder_out=encoder_out,
+                processed_lens=processed_lens,
+                rnnt_decoding_config=rnnt_decoding_config,
+                rnnt_decoding_streams_list=rnnt_decoding_streams_list,
+                num_paths=self.beam_search_params["num_paths"],
+                nbest_scale=self.beam_search_params["nbest_scale"],
+                use_double_scores=True,
+                temperature=self.beam_search_params["temperature"],
+            )
        elif self.decoding_method == "fast_beam_search":
            next_hyp_list = fast_beam_search_one_best(
                model=model,
@@ -147,14 +170,20 @@ def process(
            s.processed_frames += encoder_out_lens[i]
            s.hyp = next_hyp_list[i]

-    def get_texts(self, stream: Stream):
+    def get_texts(self, stream: Stream) -> str:
        """
        Return text after decoding
        Args:
          stream:
            Stream to be processed.
        """
-        return self.sp.decode(stream.hyp)
+        if self.decoding_method == "fast_beam_search_nbest_LG":
+            result = [self.word_table[i] for i in stream.hyp]
+            result = " ".join(result)
+        else:
+            result = self.sp.decode(stream.hyp)
+
+        return result


class GreedySearch:
@@ -273,7 +302,7 @@ def process(
            s.decoder_out = next_decoder_out_list[i]
            s.hyp = next_hyp_list[i]

-    def get_texts(self, stream: Stream):
+    def get_texts(self, stream: Stream) -> str:
        """
        Return text after decoding
        Args:
@@ -362,7 +391,7 @@ def process(
            s.states = next_state_list[i]
            s.hyps = next_hyps_list[i]

-    def get_texts(self, stream: Stream):
+    def get_texts(self, stream: Stream) -> str:
        hyp = stream.hyps.get_most_probable(True).ys[
            self.beam_search_params["context_size"] :
        ]
42 changes: 36 additions & 6 deletions sherpa/bin/streaming_pruned_transducer_statelessX/beam_search.py
@@ -6,7 +6,9 @@

from sherpa import (
    VALID_FAST_BEAM_SEARCH_METHOD,
+    Lexicon,
    fast_beam_search_nbest,
+    fast_beam_search_nbest_LG,
    fast_beam_search_one_best,
    streaming_greedy_search,
)
@@ -39,9 +41,18 @@ def __init__(
            max_states=beam_search_params["max_states"],
            max_contexts=beam_search_params["max_contexts"],
        )
-        self.decoding_graph = k2.trivial_graph(
-            beam_search_params["vocab_size"] - 1, device
-        )
+        if decoding_method == "fast_beam_search_nbest_LG":
+            lexicon = Lexicon(beam_search_params["lang_dir"])
+            self.word_table = lexicon.word_table
+            lg_filename = beam_search_params["lang_dir"] / "LG.pt"
+            self.decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            self.decoding_graph.scores *= beam_search_params["ngram_lm_scale"]
+        else:
+            self.decoding_graph = k2.trivial_graph(
+                beam_search_params["vocab_size"] - 1, device
+            )
        self.device = device
        self.context_size = beam_search_params["context_size"]
        self.beam_search_params = beam_search_params
@@ -140,6 +151,18 @@ def process(
                use_double_scores=True,
                temperature=self.beam_search_params["temperature"],
            )
+        elif self.decoding_method == "fast_beam_search_nbest_LG":
+            next_hyp_list = fast_beam_search_nbest_LG(
+                model=model,
+                encoder_out=encoder_out,
+                processed_lens=processed_lens,
+                rnnt_decoding_config=rnnt_decoding_config,
+                rnnt_decoding_streams_list=rnnt_decoding_streams_list,
+                num_paths=self.beam_search_params["num_paths"],
+                nbest_scale=self.beam_search_params["nbest_scale"],
+                use_double_scores=True,
+                temperature=self.beam_search_params["temperature"],
+            )
        elif self.decoding_method == "fast_beam_search":
            next_hyp_list = fast_beam_search_one_best(
                model=model,
@@ -163,17 +186,22 @@ def process(
            s.processed_frames += encoder_out_lens[i]
            s.hyp = next_hyp_list[i]

-    def get_texts(self, stream: Stream):
+    def get_texts(self, stream: Stream) -> str:
        """
        Return text after decoding
        Args:
          stream:
            Stream to be processed.
        """
-        if hasattr(self, "sp"):
+        if self.decoding_method == "fast_beam_search_nbest_LG":
+            result = [self.word_table[i] for i in stream.hyp]
+            result = " ".join(result)
+        elif hasattr(self, "sp"):
            result = self.sp.decode(stream.hyp)
        else:
            result = [self.token_table[i] for i in stream.hyp]
            result = "".join(result).replace("▁", " ")

        return result


@@ -314,7 +342,7 @@ def process(
            s.decoder_out = next_decoder_out_list[i]
            s.hyp = next_hyp_list[i]

-    def get_texts(self, stream: Stream):
+    def get_texts(self, stream: Stream) -> str:
        """
        Return text after decoding
        Args:
@@ -330,4 +358,6 @@ def get_texts(self, stream: Stream):
            self.token_table[i]
            for i in stream.hyp[self.beam_search_params["context_size"] :]
        ]  # noqa
+        result = "".join(result).replace("▁", " ")
+
        return result
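Alongside the annotation changes, this file picks up a small fix visible in the hunk just above: GreedySearch.get_texts for token-table models now joins the token pieces into a string before returning. The idiom relies on SentencePiece's word-boundary marker; a self-contained illustration:

# "▁" (U+2581) marks word starts in SentencePiece BPE output: concatenate
# the pieces, then turn each "▁" into a space to recover the text.
tokens = ["▁HELLO", "▁WORL", "D"]  # illustrative token strings
text = "".join(tokens).replace("▁", " ").strip()  # strip the leading space
assert text == "HELLO WORLD"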
2 changes: 2 additions & 0 deletions sherpa/python/sherpa/__init__.py
@@ -23,7 +23,9 @@
from .decode import (
    VALID_FAST_BEAM_SEARCH_METHOD,
    fast_beam_search_nbest,
+    fast_beam_search_nbest_LG,
    fast_beam_search_one_best,
)
+from .lexicon import Lexicon
from .nbest import Nbest
from .utils import add_beam_search_arguments
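With these two exports in place, the bin scripts above can pull everything the new method needs from the top-level package. A minimal consumer-side illustration (arguments elided; see the beam_search.py files above for the real call sites):

from sherpa import Lexicon, fast_beam_search_nbest_LG

# word_table = Lexicon(lang_dir).word_table      # word ID -> word string
# hyps = fast_beam_search_nbest_LG(model=..., encoder_out=..., ...)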