Skip to content

Commit 88e7b85

Browse files
[src] Add a new binary for chain e2e (kaldi-asr#3945)
This is for use in an external project.
1 parent 673c6fc commit 88e7b85

File tree

2 files changed

+123
-1
lines changed

2 files changed

+123
-1
lines changed

src/chainbin/Makefile

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ BINFILES = chain-est-phone-lm chain-get-supervision chain-make-den-fst \
1111
nnet3-chain-shuffle-egs nnet3-chain-subset-egs \
1212
nnet3-chain-acc-lda-stats nnet3-chain-train nnet3-chain-compute-prob \
1313
nnet3-chain-combine nnet3-chain-normalize-egs \
14-
nnet3-chain-e2e-get-egs nnet3-chain-compute-post
14+
nnet3-chain-e2e-get-egs nnet3-chain-compute-post \
15+
chain-make-num-fst-e2e
1516

1617

1718
OBJFILES =
+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// chainbin/chain-make-num-fst-e2e.cc
2+
3+
// Copyright 2020 Yiwen Shao
4+
5+
// See ../../COPYING for clarification regarding multiple authors
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15+
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16+
// MERCHANTABLITY OR NON-INFRINGEMENT.
17+
// See the Apache 2 License for the specific language governing permissions and
18+
// limitations under the License.
19+
20+
/** @brief Converts fsts (containing transition-ids) to fsts (containing pdf-ids + 1).
21+
*/
22+
#include "base/kaldi-common.h"
23+
#include "gmm/am-diag-gmm.h"
24+
#include "hmm/transition-model.h"
25+
#include "hmm/hmm-utils.h"
26+
#include "util/common-utils.h"
27+
#include "fst/fstlib.h"
28+
29+
namespace kaldi {
30+
31+
bool FstTransitionToPdfPlusOne(const fst::StdVectorFst &fst_transition,
32+
const TransitionModel &trans_model,
33+
fst::StdVectorFst *fst_pdf) {
34+
fst::StdVectorFst fst_tmp(fst_transition);
35+
fst::RemoveEpsLocal(&fst_tmp);
36+
fst::RmEpsilon(&fst_tmp);
37+
// first change labels to pdf-id + 1
38+
int32 num_states = fst_tmp.NumStates();
39+
for (int32 state = 0; state < num_states; state++) {
40+
for (fst::MutableArcIterator<fst::StdVectorFst> aiter(&fst_tmp, state);
41+
!aiter.Done(); aiter.Next()) {
42+
const fst::StdArc &arc = aiter.Value();
43+
if (arc.ilabel == 0) {
44+
KALDI_WARN << "Utterance rejected due to eps on input label";
45+
return false;
46+
}
47+
KALDI_ASSERT(arc.ilabel != 0);
48+
fst::StdArc arc2(arc);
49+
arc2.ilabel = arc2.olabel = trans_model.TransitionIdToPdf(arc.ilabel) + 1;
50+
aiter.SetValue(arc2);
51+
}
52+
}
53+
*fst_pdf = fst_tmp;
54+
return true;
55+
}
56+
57+
bool AddWeightToFst(const fst::StdVectorFst &normalization_fst,
58+
fst::StdVectorFst *fst) {
59+
// Note: by default, 'Compose' will call 'Connect', so if the
60+
// resulting FST is not connected, it will end up empty.
61+
fst::StdVectorFst composed_fst;
62+
fst::Compose(*fst, normalization_fst,
63+
&composed_fst);
64+
*fst = composed_fst;
65+
if (composed_fst.NumStates() == 0)
66+
return false;
67+
return true;
68+
}
69+
70+
}
71+
72+
int main(int argc, char *argv[]) {
73+
using namespace kaldi;
74+
typedef kaldi::int32 int32;
75+
try {
76+
const char *usage =
77+
"Converts chain e2e numerator fst (containing transition-ids) to fst (containing pdf-ids+1, \n"
78+
"and composed by the normalization fst) \n"
79+
"Usage: chain-make-num-fst-e2e [options] <model> <normalization-fst>\n"
80+
"<trainsition-fst-rspecifier> <pdf-fst-wspecifier>\n"
81+
"e.g.: \n"
82+
" chain-make-num-fst-e2e 1.mdl ark:1.fst ark,t:-\n";
83+
ParseOptions po(usage);
84+
85+
po.Read(argc, argv);
86+
87+
if (po.NumArgs() != 4) {
88+
po.PrintUsage();
89+
exit(1);
90+
}
91+
92+
std::string model_filename = po.GetArg(1),
93+
normalization_fst_rxfilename = po.GetArg(2),
94+
fsts_rspecifier = po.GetArg(3),
95+
fsts_wspecifier = po.GetArg(4);
96+
97+
TransitionModel trans_model;
98+
ReadKaldiObject(model_filename, &trans_model);
99+
100+
fst::StdVectorFst normalization_fst;
101+
ReadFstKaldi(normalization_fst_rxfilename, &normalization_fst);
102+
103+
SequentialTableReader<fst::VectorFstHolder> fsts_reader(fsts_rspecifier);
104+
TableWriter<fst::VectorFstHolder> fsts_writer(fsts_wspecifier);
105+
106+
int32 num_done = 0;
107+
for (; !fsts_reader.Done(); fsts_reader.Next()) {
108+
std::string key = fsts_reader.Key();
109+
fst::VectorFst<fst::StdArc> fst_transition(fsts_reader.Value());
110+
fst::StdVectorFst fst_pdf;
111+
FstTransitionToPdfPlusOne(fst_transition, trans_model, &fst_pdf);
112+
AddWeightToFst(normalization_fst, &fst_pdf);
113+
fsts_writer.Write(key, fst_pdf);
114+
num_done++;
115+
}
116+
KALDI_LOG << "Converted " << num_done << " Fsts with transition-id to Fsts with pdf-id and normalized.";
117+
} catch(const std::exception &e) {
118+
std::cerr << e.what();
119+
return -1;
120+
}
121+
}

0 commit comments

Comments
 (0)