Skip to content

Commit

Permalink
Add offline_asr
Browse files: browse the repository at this point in the history
  • Loading branch information
ezerhouni committed Jul 26, 2022
1 parent 5659d4b commit 33aa226
Showing 1 changed file with 17 additions and 36 deletions.
53 changes: 17 additions & 36 deletions sherpa/bin/pruned_transducer_statelessX/offline_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,14 @@
import torchaudio
from beam_search import GreedySearchOffline, ModifiedBeamSearchOffline

from sherpa import RnntConformerModel
from sherpa import RnntConformerModel, add_beam_search_arguments


def get_args():
beam_search_parser = add_beam_search_arguments()
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
parents=[beam_search_parser],
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
Expand Down Expand Up @@ -136,26 +138,6 @@ def get_args():
""",
)

parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Decoding method to use. Currently, only greedy_search and
modified_beam_search are implemented.
""",
)

parser.add_argument(
"--num-active-paths",
type=int,
default=4,
help="""Used only when decoding_method is modified_beam_search.
It specifies number of active paths for each utterance. Due to
merging paths with identical token sequences, the actual number
may be less than "num_active_paths".
""",
)

parser.add_argument(
"--sample-rate",
type=int,
Expand All @@ -173,7 +155,10 @@ def get_args():
"The sample rate has to equal to `--sample-rate`.",
)

return parser.parse_args()
return (
parser.parse_args(),
beam_search_parser.parse_known_args()[0],
)


def read_sound_files(
Expand Down Expand Up @@ -207,10 +192,10 @@ def __init__(
nn_model_filename: str,
bpe_model_filename: Optional[str],
token_filename: Optional[str],
decoding_method: str,
num_active_paths: int,
sample_rate: int = 16000,
device: Union[str, torch.device] = "cpu",
beam_search_params: dict = {},
):
"""
Args:
Expand All @@ -222,9 +207,6 @@ def __init__(
token_filename:
Path to tokens.txt. If it is None, you have to provide
`bpe_model_filename`.
decoding_method:
The decoding method to use. Currently, only greedy_search and
modified_beam_search are implemented.
num_active_paths:
Used only when decoding_method is modified_beam_search.
It specifies number of active paths for each utterance. Due to
Expand All @@ -234,6 +216,8 @@ def __init__(
Expected sample rate of the feature extractor.
device:
The device to use for computation.
beam_search_params:
Dictionary containing all the parameters for beam search.
"""
self.model = RnntConformerModel(
filename=nn_model_filename,
Expand All @@ -252,15 +236,11 @@ def __init__(
device=device,
)

assert decoding_method in (
"greedy_search",
"modified_beam_search",
), decoding_method

decoding_method = beam_search_params["decoding_method"]
if decoding_method == "greedy_search":
self.beam_search = GreedySearchOffline()
elif decoding_method == "modified_beam_search":
self.beam_search = ModifiedBeamSearchOffline(num_active_paths)
self.beam_search = ModifiedBeamSearchOffline(beam_search_params)
else:
raise ValueError(
f"Decoding method {decoding_method} is not supported."
Expand Down Expand Up @@ -328,17 +308,18 @@ def decode_waves(self, waves: List[torch.Tensor]) -> List[List[str]]:

@torch.no_grad()
def main():
args = get_args()
args, beam_search_parser = get_args()
beam_search_params = vars(beam_search_parser)
logging.info(vars(args))

nn_model_filename = args.nn_model_filename
bpe_model_filename = args.bpe_model_filename
token_filename = args.token_filename
decoding_method = args.decoding_method
num_active_paths = args.num_active_paths
sample_rate = args.sample_rate
sound_files = args.sound_files

decoding_method = beam_search_params["decoding_method"]
assert decoding_method in (
"greedy_search",
"modified_beam_search",
Expand Down Expand Up @@ -374,10 +355,10 @@ def main():
nn_model_filename=nn_model_filename,
bpe_model_filename=bpe_model_filename,
token_filename=token_filename,
decoding_method=decoding_method,
num_active_paths=num_active_paths,
sample_rate=sample_rate,
device=device,
beam_search_params=beam_search_params,
)

waves = read_sound_files(
Expand Down

0 comments on commit 33aa226

Please sign in to comment.