Skip to content

Commit

Permalink
Merge pull request #82 from ezerhouni/add-fast-beam-search-nbest
Browse files Browse the repository at this point in the history
Add fast_beam_search_nbest for streaming pruned rnn-t
  • Loading branch information
ezerhouni authored Jul 26, 2022
2 parents 7f67801 + 751fdda commit adba776
Show file tree
Hide file tree
Showing 25 changed files with 1,017 additions and 429 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build-doc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ jobs:
python3 -m torch.utils.collect_env
- name: Install sherpa
shell: bash
run: |
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/run-streaming-conformer-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
torch: ["1.10.0", "1.6.0"]
torchaudio: ["0.10.0", "0.6.0"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "fast_beam_search"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest"]
exclude:
- torch: "1.10.0"
torchaudio: "0.6.0"
Expand Down Expand Up @@ -75,6 +75,7 @@ jobs:
python3 -m torch.utils.collect_env
- name: Install PyTorch ${{ matrix.torch }}
shell: bash
if: startsWith(matrix.os, 'macos')
Expand All @@ -87,6 +88,7 @@ jobs:
python3 -m torch.utils.collect_env
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/run-streaming-conv-emformer-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
torch: ["1.10.0", "1.6.0"]
torchaudio: ["0.10.0", "0.6.0"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "fast_beam_search"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest"]
exclude:
- torch: "1.10.0"
torchaudio: "0.6.0"
Expand Down Expand Up @@ -75,6 +75,7 @@ jobs:
python3 -m torch.utils.collect_env
- name: Install PyTorch ${{ matrix.torch }}
shell: bash
if: startsWith(matrix.os, 'macos')
Expand All @@ -87,6 +88,7 @@ jobs:
python3 -m torch.utils.collect_env
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/run-streaming-test-windows-cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
torch: ["1.10.0"]
torchaudio: ["0.10.0"]
python-version: ["3.7", "3.8", "3.9"]
decoding: ["greedy_search", "modified_beam_search", "fast_beam_search"]
decoding: ["greedy_search", "modified_beam_search", "fast_beam_search", "fast_beam_search_nbest"]
steps:
- uses: actions/checkout@v2
with:
Expand Down Expand Up @@ -63,6 +63,7 @@ jobs:
python3 -m torch.utils.collect_env
- name: Install kaldifeat
shell: bash
run: |
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/run-streaming-test-windows-cuda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ jobs:
python3 -m torch.utils.collect_env
- name: Display CMake version
run: |
cmake --version
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/run-streaming-test-with-long-waves.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ jobs:
python3 -m torch.utils.collect_env
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/run-streaming-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
torch: ["1.10.0", "1.6.0"]
torchaudio: ["0.10.0", "0.6.0"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "modified_beam_search", "fast_beam_search"]
decoding: ["greedy_search", "modified_beam_search", "fast_beam_search", "fast_beam_search_nbest"]
exclude:
- torch: "1.10.0"
torchaudio: "0.6.0"
Expand Down Expand Up @@ -70,6 +70,7 @@ jobs:
python3 -m torch.utils.collect_env
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
else
Expand All @@ -88,6 +89,7 @@ jobs:
python3 -m torch.utils.collect_env
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/run-test-aishell.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ jobs:
python3 -m torch.utils.collect_env
- name: Install PyTorch ${{ matrix.torch }}
shell: bash
if: startsWith(matrix.os, 'macos')
Expand All @@ -89,6 +90,7 @@ jobs:
python3 -m torch.utils.collect_env
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/run-test-windows-cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ jobs:
python3 -m torch.utils.collect_env
- name: Install kaldifeat
shell: bash
run: |
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/run-test-windows-cuda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ jobs:
python3 -m pip install k2==1.16.dev20220621+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/index.html
python3 -m torch.utils.collect_env
- name: Display CMake version
run: |
cmake --version
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/run-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ jobs:
python3 -m torch.utils.collect_env
- name: Install PyTorch ${{ matrix.torch }}
shell: bash
if: startsWith(matrix.os, 'macos')
Expand All @@ -89,6 +90,7 @@ jobs:
python3 -m torch.utils.collect_env
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
torch: ["1.11.0", "1.7.1"]
torchaudio: ["0.11.0", "0.7.2"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "fast_beam_search"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest"]
exclude:
- torch: "1.11.0"
torchaudio: "0.7.2"
Expand Down Expand Up @@ -71,6 +71,8 @@ jobs:
pip install torchaudio==${{ matrix.torchaudio }}
fi
python3 -m torch.utils.collect_env
- name: Install PyTorch ${{ matrix.torch }}
shell: bash
if: startsWith(matrix.os, 'macos')
Expand All @@ -80,6 +82,7 @@ jobs:
python3 -m pip install torch==${{ matrix.torch }} torchaudio==${{ matrix.torchaudio }} numpy -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install k2==1.16.dev20220621+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/index.html
python3 -m torch.utils.collect_env
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
Expand Down Expand Up @@ -107,6 +110,7 @@ jobs:
run: |
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
./sherpa/bin/streaming_pruned_transducer_statelessX/streaming_server.py \
--port 6006 \
--max-batch-size 50 \
Expand Down
108 changes: 71 additions & 37 deletions sherpa/bin/conv_emformer_transducer_stateless2/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,44 +4,47 @@
import torch
from stream import Stream, stack_states, unstack_states

from sherpa import fast_beam_search_one_best, streaming_greedy_search
from sherpa import (
VALID_FAST_BEAM_SEARCH_METHOD,
fast_beam_search_nbest,
fast_beam_search_one_best,
streaming_greedy_search,
)


class FastBeamSearch:
def __init__(
self,
vocab_size: int,
context_size: int,
beam: int,
max_states: int,
max_contexts: int,
beam_search_params: dict,
device: torch.device,
):
"""
Args:
vocab_size:
Vocabularize of the BPE
context_size:
Context size of the RNN-T decoder model.
beam:
The beam for fast_beam_search decoding.
max_states:
The max_states for fast_beam_search decoding.
max_contexts:
The max_contexts for fast_beam_search decoding.
beam_search_params
Dictionary containing all the parameters for beam search.
device:
Device on which the computation will occur
"""

decoding_method = beam_search_params["decoding_method"]
assert (
decoding_method in VALID_FAST_BEAM_SEARCH_METHOD
), f"{decoding_method} is not a valid search method"

self.decoding_method = decoding_method
self.rnnt_decoding_config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
vocab_size=beam_search_params["vocab_size"],
decoder_history_len=beam_search_params["context_size"],
beam=beam_search_params["beam"],
max_states=beam_search_params["max_states"],
max_contexts=beam_search_params["max_contexts"],
)
self.decoding_graph = k2.trivial_graph(
beam_search_params["vocab_size"] - 1, device
)
self.decoding_graph = k2.trivial_graph(vocab_size - 1, device)
self.device = device
self.context_size = context_size
self.context_size = beam_search_params["context_size"]
self.beam_search_params = beam_search_params

def init_stream(self, stream: Stream):
"""
Expand Down Expand Up @@ -115,13 +118,30 @@ def process(
)

processed_lens = (num_processed_frames >> 2) + encoder_out_lens
next_hyp_list = fast_beam_search_one_best(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
rnnt_decoding_config=rnnt_decoding_config,
rnnt_decoding_streams_list=rnnt_decoding_streams_list,
)
if self.decoding_method == "fast_beam_search_nbest":
next_hyp_list = fast_beam_search_nbest(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
rnnt_decoding_config=rnnt_decoding_config,
rnnt_decoding_streams_list=rnnt_decoding_streams_list,
num_paths=self.beam_search_params["num_paths"],
nbest_scale=self.beam_search_params["nbest_scale"],
use_double_scores=True,
temperature=self.beam_search_params["temperature"],
)
elif self.decoding_method == "fast_beam_search":
next_hyp_list = fast_beam_search_one_best(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
rnnt_decoding_config=rnnt_decoding_config,
rnnt_decoding_streams_list=rnnt_decoding_streams_list,
)
else:
raise NotImplementedError(
f"{self.decoding_method} is not implemented"
)

next_state_list = unstack_states(next_states)
for i, s in enumerate(stream_list):
Expand All @@ -139,24 +159,34 @@ def get_texts(self, stream: Stream):


class GreedySearch:
def __init__(self, model: "RnntConvEmformerModel", device: torch.device):
def __init__(
self,
model: "RnntConvEmformerModel",
beam_search_params: dict,
device: torch.device,
):
"""
Args:
model:
RNN-T model decoder model
beam_search_params:
Dictionary containing all the parameters for beam search.
device:
Device on which the computation will occur
"""

self.blank_id = model.blank_id
self.context_size = model.context_size
self.device = device
self.beam_search_params = beam_search_params
self.device = device

decoder_input = torch.tensor(
[[self.blank_id] * self.context_size],
[
[self.beam_search_params["blank_id"]]
* self.beam_search_params["context_size"]
],
device=self.device,
dtype=torch.int64,
)

initial_decoder_out = model.decoder_forward(decoder_input)
self.initial_decoder_out = model.forward_decoder_proj(
initial_decoder_out.squeeze(1)
Expand All @@ -167,7 +197,9 @@ def init_stream(self, stream: Stream):
Attributes to add to each stream
"""
stream.decoder_out = self.initial_decoder_out
stream.hyp = [self.blank_id] * self.context_size
stream.hyp = [
self.beam_search_params["blank_id"]
] * self.beam_search_params["context_size"]

@torch.no_grad()
def process(
Expand Down Expand Up @@ -259,4 +291,6 @@ def get_texts(self, stream: Stream):
stream:
Stream to be processed.
"""
return self.sp.decode(stream.hyp[self.context_size :])
return self.sp.decode(
stream.hyp[self.beam_search_params["context_size"] :]
)
Loading

0 comments on commit adba776

Please sign in to comment.