Skip to content
This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

[WIP] 2-state HMM topo as an alternative to CTC topo #126

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
Decoding changes for 2-state HMM topo
  • Loading branch information
pzelasko committed Mar 15, 2021
commit ee992eb63d34b464b419fae28741510969f4bc7f
19 changes: 10 additions & 9 deletions egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
from snowfall.common import setup_logger
from snowfall.decoding.graph import compile_HLG
from snowfall.models import AcousticModel
from snowfall.models.transformer import Transformer
from snowfall.models.conformer import Conformer
from snowfall.training.ctc_graph import build_ctc_topo
from snowfall.models.transformer import Transformer
from snowfall.training.hmm_topo import build_hmm_topo_2state
from snowfall.training.mmi_graph import create_bigram_phone_lm
from snowfall.training.mmi_graph import get_phone_symbols

Expand Down Expand Up @@ -218,7 +218,8 @@ def main():
P = create_bigram_phone_lm(phone_ids)

phone_ids_with_blank = [0] + phone_ids
ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))
# H = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))
H = build_hmm_topo_2state(phone_ids_with_blank)

logging.debug("About to load model")
# Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
Expand All @@ -235,15 +236,15 @@ def main():
num_features=40,
nhead=args.nhead,
d_model=args.attention_dim,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
num_classes=2 * len(phone_ids) + 2, # +1 for the blank symbol
subsampling_factor=4,
num_decoder_layers=num_decoder_layers)
else:
model = Conformer(
num_features=40,
nhead=args.nhead,
d_model=args.attention_dim,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
num_classes=2 * len(phone_ids) + 2, # +1 for the blank symbol
subsampling_factor=4,
num_decoder_layers=num_decoder_layers)

Expand Down Expand Up @@ -277,10 +278,10 @@ def main():
first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table)
first_word_disambig_id = find_first_disambig_symbol(symbol_table)
HLG = compile_HLG(L=L,
G=G,
H=ctc_topo,
labels_disambig_id_start=first_phone_disambig_id,
aux_labels_disambig_id_start=first_word_disambig_id)
G=G,
H=H,
labels_disambig_id_start=first_phone_disambig_id,
aux_labels_disambig_id_start=first_word_disambig_id)
torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
else:
logging.debug("Loading pre-compiled HLG")
Expand Down