Skip to content

Commit 083c64d

Browse files
authored
[src] CudaDecoder endpointing (#4146, #4101)
* Partial hypotheses * PR comments * Endpointing * PR comments * Neg non-em partial traceback bug fix
1 parent d59f03b commit 083c64d

10 files changed

+545
-135
lines changed

src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@ void BatchedThreadedNnet3CudaOnlinePipeline::AllocateAndInitializeData(
8585
cuda_decoder_->SetThreadPoolAndStartCPUWorkers(
8686
thread_pool_.get(), config_.num_decoder_copy_threads);
8787
}
88+
89+
cuda_decoder_->SetOutputFrameShiftInSeconds(
90+
feature_info_->FrameShiftInSeconds() *
91+
config_.compute_opts.frame_subsampling_factor);
92+
8893
n_samples_valid_.resize(max_batch_size_);
8994
n_input_frames_valid_.resize(max_batch_size_);
9095
n_lattice_callbacks_not_done_.store(0);
@@ -217,7 +222,9 @@ void BatchedThreadedNnet3CudaOnlinePipeline::DecodeBatch(
217222
const std::vector<CorrelationID> &corr_ids,
218223
const std::vector<SubVector<BaseFloat>> &wave_samples,
219224
const std::vector<bool> &is_first_chunk,
220-
const std::vector<bool> &is_last_chunk) {
225+
const std::vector<bool> &is_last_chunk,
226+
std::vector<const std::string *> *partial_hypotheses,
227+
std::vector<bool> *end_points) {
221228
nvtxRangePushA("DecodeBatch");
222229
KALDI_ASSERT(corr_ids.size() > 0);
223230
KALDI_ASSERT(corr_ids.size() == wave_samples.size());
@@ -242,9 +249,37 @@ void BatchedThreadedNnet3CudaOnlinePipeline::DecodeBatch(
242249
}
243250
}
244251
int features_frame_stride = d_all_features_.Stride();
252+
if (partial_hypotheses) {
253+
// We're going to have to generate the partial hypotheses
254+
if (word_syms_ == nullptr) {
255+
KALDI_ERR << "You need to set --word-symbol-table to use "
256+
<< "partial hypotheses";
257+
}
258+
cuda_decoder_->AllowPartialHypotheses();
259+
}
260+
if (end_points) cuda_decoder_->AllowEndpointing();
261+
245262
DecodeBatch(corr_ids, d_features_ptrs_, features_frame_stride,
246263
n_input_frames_valid_, d_ivectors_ptrs_, is_first_chunk,
247264
is_last_chunk, &channels_);
265+
266+
if (partial_hypotheses) {
267+
partial_hypotheses->resize(channels_.size());
268+
for (size_t i = 0; i < channels_.size(); ++i) {
269+
PartialHypothesis *partial_hypothesis;
270+
ChannelId ichannel = channels_[i];
271+
cuda_decoder_->GetPartialHypothesis(ichannel, &partial_hypothesis);
272+
(*partial_hypotheses)[i] = &partial_hypothesis->out_str;
273+
}
274+
}
275+
276+
if (end_points) {
277+
end_points->resize(channels_.size());
278+
for (size_t i = 0; i < channels_.size(); ++i) {
279+
ChannelId ichannel = channels_[i];
280+
(*end_points)[i] = cuda_decoder_->EndpointDetected(ichannel);
281+
}
282+
}
248283
}
249284

250285
void BatchedThreadedNnet3CudaOnlinePipeline::DecodeBatch(
@@ -393,10 +428,11 @@ void BatchedThreadedNnet3CudaOnlinePipeline::ReadParametersFromModel() {
393428
input_dim_ = feature.InputFeature()->Dim();
394429
if (use_ivectors_) ivector_dim_ = feature.IvectorFeature()->Dim();
395430
model_frequency_ = feature_info_->GetSamplingFrequency();
396-
BaseFloat frame_shift = feature_info_->FrameShiftInSeconds();
431+
BaseFloat frame_shift_seconds = feature_info_->FrameShiftInSeconds();
397432
input_frames_per_chunk_ = config_.compute_opts.frames_per_chunk;
398-
seconds_per_chunk_ = input_frames_per_chunk_ * frame_shift;
399-
int32 samp_per_frame = static_cast<int>(model_frequency_ * frame_shift);
433+
seconds_per_chunk_ = input_frames_per_chunk_ * frame_shift_seconds;
434+
int32 samp_per_frame =
435+
static_cast<int>(model_frequency_ * frame_shift_seconds);
400436
samples_per_chunk_ = input_frames_per_chunk_ * samp_per_frame;
401437
BatchedStaticNnet3Config nnet3_config;
402438
nnet3_config.compute_opts = config_.compute_opts;
@@ -433,7 +469,8 @@ void BatchedThreadedNnet3CudaOnlinePipeline::FinalizeDecoding(
433469
}
434470

435471
if (dlat.NumStates() > 0) {
436-
if (word_syms_) {
472+
// Used for debugging
473+
if (false && word_syms_) {
437474
CompactLattice best_path_clat;
438475
CompactLatticeShortestPath(dlat, &best_path_clat);
439476

src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,18 @@ class BatchedThreadedNnet3CudaOnlinePipeline {
169169
// If it contains some last chunks for given utterances, it will call
170170
// FinalizeDecoding (building the final lattice, determinize it, etc.)
171171
// asynchronously. The callback for that utterance will then be called
172-
void DecodeBatch(const std::vector<CorrelationID> &corr_ids,
173-
const std::vector<SubVector<BaseFloat>> &wave_samples,
174-
const std::vector<bool> &is_first_chunk,
175-
const std::vector<bool> &is_last_chunk);
172+
//
173+
// If partial_hypotheses is not null, generate and set the current partial
174+
// hypotheses in partial_hypotheses. The pointers in partial_hypotheses are
175+
// only valid until the next DecodeBatch call - perform a deep copy if
176+
// necessary
177+
void DecodeBatch(
178+
const std::vector<CorrelationID> &corr_ids,
179+
const std::vector<SubVector<BaseFloat>> &wave_samples,
180+
const std::vector<bool> &is_first_chunk,
181+
const std::vector<bool> &is_last_chunk,
182+
std::vector<const std::string *> *partial_hypotheses = nullptr,
183+
std::vector<bool> *end_point = nullptr);
176184

177185
// Version providing directly the features. Only runs nnet3 & decoder
178186
// Used when we want to provide the final ivectors (offline case)
@@ -207,8 +215,12 @@ class BatchedThreadedNnet3CudaOnlinePipeline {
207215
// Maximum number of seconds per chunk
208216
BaseFloat GetSecondsPerChunk() { return seconds_per_chunk_; }
209217

210-
// Used when debugging. Used to Print the text when a decoding is done
211-
void SetSymbolTable(fst::SymbolTable *word_syms) { word_syms_ = word_syms; }
218+
// Used for partial hypotheses
219+
void SetSymbolTable(const fst::SymbolTable &word_syms) {
220+
word_syms_ = &word_syms;
221+
KALDI_ASSERT(cuda_decoder_);
222+
cuda_decoder_->SetSymbolTable(word_syms);
223+
}
212224

213225
// Wait for all lattice callbacks to complete
214226
// Can be called after DecodeBatch
@@ -355,7 +367,7 @@ class BatchedThreadedNnet3CudaOnlinePipeline {
355367
std::unique_ptr<ThreadPoolLight> thread_pool_;
356368

357369
// Used for partial hypotheses and for debugging
358-
fst::SymbolTable *word_syms_;
370+
const fst::SymbolTable *word_syms_;
359371
// Used when printing to stdout for debugging purposes
360372
std::mutex stdout_m_;
361373
};

src/cudadecoder/batched-threaded-nnet3-cuda-pipeline2.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ class BatchedThreadedNnet3CudaPipeline2 {
245245
void WaitForAllTasks();
246246

247247
// Used for debug
248-
void SetSymbolTable(fst::SymbolTable *word_syms) {
248+
void SetSymbolTable(const fst::SymbolTable &word_syms) {
249249
cuda_online_pipeline_.SetSymbolTable(word_syms);
250250
}
251251

src/cudadecoder/cuda-decoder-common.h

Lines changed: 70 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,56 @@ class DeviceChannelMatrix : public DeviceMatrix<T> {
345345
}
346346
};
347347

348+
// InfoToken contains data that needs to be saved for the backtrack
349+
// in GetBestPath/GetRawLattice
350+
// We don't need the token.cost or token.next_state.
351+
struct __align__(8) InfoToken {
352+
int32 prev_token;
353+
int32 arc_idx;
354+
bool IsUniqueTokenForStateAndFrame() {
355+
// This is a trick used to save space and PCI-E bandwidth (cf
356+
// preprocess_in_place kernel)
357+
// This token is associated with a next_state s, created during the
358+
// processing of frame f.
359+
// If we have multiple tokens associated with the state s in the frame f,
360+
// arc_idx < 0 and -arc_idx is the
361+
// count of such tokens. We will then have to look at another list to read
362+
// the actual arc_idx and prev_token values
363+
// If the current token is the only one, prev_token and arc_idx are valid
364+
// and can be used directly
365+
return (arc_idx >= 0);
366+
}
367+
368+
// Called if this token is linked to other tokens in the same frame (cf
369+
// comments for IsUniqueTokenForStateAndFrame)
370+
// return the {offset,size} pair necessary to list those tokens in the
371+
// extra_prev_tokens list
372+
// They are stored at offset "offset", and we have "size" of those
373+
std::pair<int32, int32> GetSameFSTStateTokensList() {
374+
KALDI_ASSERT(!IsUniqueTokenForStateAndFrame());
375+
376+
return {prev_token, -arc_idx};
377+
}
378+
};
379+
380+
// Device function, used to set in an InfoToken the [offset,size] related to
381+
// InfoToken.GetSameFSTStateTokensList
382+
__device__ __inline__ void SetSameFSTStateTokensList(int32 offset, int32 size,
383+
InfoToken *info_token) {
384+
// We always have size > 0
385+
*info_token = {offset, -size};
386+
}
387+
388+
// Information about the best path head
389+
// Used by partial hypotheses and endpointing
390+
struct BestPathTracebackHead {
391+
int index;
392+
CostType relative_cost;
393+
394+
void Reset() { index = -1; }
395+
bool IsSet() { return (index != -1); }
396+
};
397+
348398
// LaneCounters/ChannelCounters
349399
// The counters are all the singular values associated to a lane/channel
350400
// For instance the main queue size. Or the min_cost of all tokens in that
@@ -402,6 +452,7 @@ struct LaneCounters {
402452
int32 main_q_extra_prev_tokens_global_offset;
403453
// Minimum token for that frame
404454
IntegerCostType min_int_cost;
455+
IntegerCostType int_relative_cost;
405456
// Current beam. Can be different from default_beam,
406457
// because of the AdaptiveBeam process, or because of
407458
// ApplyMaxActiveAndReduceBeam
@@ -471,46 +522,6 @@ class CudaDecoderException : public std::exception {
471522
const bool recoverable;
472523
};
473524

474-
// InfoToken contains data that needs to be saved for the backtrack
475-
// in GetBestPath/GetRawLattice
476-
// We don't need the token.cost or token.next_state.
477-
struct __align__(8) InfoToken {
478-
int32 prev_token;
479-
int32 arc_idx;
480-
bool IsUniqueTokenForStateAndFrame() {
481-
// This is a trick used to save space and PCI-E bandwidth (cf
482-
// preprocess_in_place kernel)
483-
// This token is associated with a next_state s, created during the
484-
// processing of frame f.
485-
// If we have multiple tokens associated with the state s in the frame f,
486-
// arc_idx < 0 and -arc_idx is the
487-
// count of such tokens. We will then have to look at another list to read
488-
// the actually arc_idx and prev_token values
489-
// If the current token is the only one, prev_token and arc_idx are valid
490-
// and can be used directly
491-
return (arc_idx >= 0);
492-
}
493-
494-
// Called if this token is linked to others tokens in the same frame (cf
495-
// comments for IsUniqueTokenForStateAndFrame)
496-
// return the {offset,size} pair necessary to list those tokens in the
497-
// extra_prev_tokens list
498-
// They are stored at offset "offset", and we have "size" of those
499-
std::pair<int32, int32> GetSameFSTStateTokensList() {
500-
KALDI_ASSERT(!IsUniqueTokenForStateAndFrame());
501-
502-
return {prev_token, -arc_idx};
503-
}
504-
};
505-
506-
// Device function, used to set a in an InfoToken the [offset,size] related to
507-
// InfoToken.GetSameFSTStateTokensList
508-
__device__ __inline__ void SetSameFSTStateTokensList(int32 offset, int32 size,
509-
InfoToken *info_token) {
510-
// We always have size > 0
511-
*info_token = {offset, -size};
512-
}
513-
514525
// Used to store the index in the GPU hashmap of that FST state
515526
// The hashmap is only generated with the final main queue (post max_active_) of
516527
// each frame
@@ -558,6 +569,25 @@ enum OVERFLOW_TYPE {
558569

559570
enum QUEUE_ID { MAIN_Q = 0, AUX_Q = 1 };
560571

572+
// Used internally to generate partial paths
573+
struct PartialPathArc {
574+
int32 token_idx;
575+
int32 arc_idx;
576+
};
577+
578+
// Partial hypothesis formatted and meant to be used by user
579+
struct PartialHypothesis {
580+
std::vector<int> arc_idx;
581+
std::vector<int> olabel;
582+
std::string out_str;
583+
584+
void clear() {
585+
arc_idx.clear();
586+
olabel.clear();
587+
out_str.clear();
588+
}
589+
};
590+
561591
} // end namespace cuda_decoder
562592
} // end namespace kaldi
563593

src/cudadecoder/cuda-decoder-kernels.cu

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,7 @@ __global__ void reset_for_frame_and_estimate_cutoff_kernel(
505505
lane_counters->int_cutoff = INT_MAX;
506506
lane_counters->min_int_cost = INT_MAX;
507507
lane_counters->q_overflow = OVERFLOW_NONE;
508+
lane_counters->int_relative_cost = INT_MAX;
508509
lane_counters->aux_q_requested = 0;
509510
lane_counters->main_q_requested = 0;
510511
lane_counters->main_q_local_offset = 0;
@@ -1122,8 +1123,7 @@ __launch_bounds__(KALDI_CUDA_DECODER_LARGEST_1D_BLOCK, 1) __global__
11221123
// One add the final_cost[token.state] before looking for the min
11231124
__global__ void get_best_cost_step1_kernel(DeviceParams cst_dev_params,
11241125
KernelParams params,
1125-
bool use_final_probs,
1126-
CostType fst_zero) {
1126+
bool use_final_probs) {
11271127
const int nlanes = params.nlanes_used;
11281128
KALDI_CUDA_DECODER_BATCH_KERNEL_LOOP(ilane, nlanes) {
11291129
LaneCounters *lane_counters = cst_dev_params.d_lanes_counters.lane(ilane);
@@ -1151,7 +1151,7 @@ __global__ void get_best_cost_step1_kernel(DeviceParams cst_dev_params,
11511151
cst_dev_params.d_fst_final_costs[token_state];
11521152
IntegerCostType int_cost_with_final =
11531153
floatToOrderedInt(cost + final_cost);
1154-
if (final_cost != fst_zero) {
1154+
if (final_cost != cst_dev_params.fst_zero) {
11551155
int2 min_and_arg = {int_cost_with_final,
11561156
global_idx}; // sort by cost, put it first
11571157
atomicMinI2(&channel_counters->min_int_cost_and_arg_with_final,
@@ -1169,8 +1169,7 @@ __global__ void get_best_cost_step1_kernel(DeviceParams cst_dev_params,
11691169
// and list them into d_list_final_tokens_in_main_q
11701170
__global__ void get_best_cost_step2_kernel(DeviceParams cst_dev_params,
11711171
KernelParams params,
1172-
bool use_final_probs,
1173-
CostType fst_zero) {
1172+
bool use_final_probs) {
11741173
const int nlanes = params.nlanes_used;
11751174
KALDI_CUDA_DECODER_BATCH_KERNEL_LOOP(ilane, nlanes) {
11761175
LaneCounters *lane_counters = cst_dev_params.d_lanes_counters.lane(ilane);
@@ -1218,7 +1217,7 @@ __global__ void get_best_cost_step2_kernel(DeviceParams cst_dev_params,
12181217
cst_dev_params.d_fst_final_costs[token_state];
12191218
const CostType token_cost = orderedIntToFloat(token_int_cost);
12201219
// final_cost == fst_zero -> this state is not final
1221-
token_int_cost = (final_cost != fst_zero)
1220+
token_int_cost = (final_cost != cst_dev_params.fst_zero)
12221221
? floatToOrderedInt(token_cost + final_cost)
12231222
: INT_MAX;
12241223
}
@@ -1401,7 +1400,7 @@ __global__ void fill_hashmap_with_main_q_kernel(DeviceParams cst_dev_params,
14011400
const int32 main_q_end = lane_counters->main_q_narcs_and_end.y;
14021401
int32 min_int_cost = lane_counters->min_int_cost;
14031402
CostType min_cost = orderedIntToFloat(min_int_cost);
1404-
const int32 global_offset = channel_counters->prev_main_q_global_offset;
1403+
const int32 global_offset = lane_counters->main_q_global_offset;
14051404
KALDI_CUDA_DECODER_1D_KERNEL_LOOP(main_q_idx, main_q_end) {
14061405
// Position of considered token in the main_q
14071406
if (main_q_idx < main_q_end) {
@@ -1490,7 +1489,7 @@ __global__ void emitting_preprocess_and_list_extra_prev_tokens_step1_kernel(
14901489
__shared__ typename BlockScan::TempStorage sh_temp_storage;
14911490
const int nlanes = params.nlanes_used;
14921491
KALDI_CUDA_DECODER_BATCH_KERNEL_LOOP(ilane, nlanes) {
1493-
const LaneCounters *lane_counters =
1492+
LaneCounters *lane_counters =
14941493
cst_dev_params.d_lanes_counters.lane(ilane);
14951494
const int32 main_q_end = lane_counters->main_q_narcs_and_end.y;
14961495
// Final cutoff from last ExpandArc execution
@@ -1561,6 +1560,17 @@ __global__ void emitting_preprocess_and_list_extra_prev_tokens_step1_kernel(
15611560
// avoid a new random memory access
15621561
cst_dev_params.d_main_q_arc_offsets.channel(ichannel)[main_q_idx] =
15631562
start;
1563+
1564+
// Saving best cost with final cost, to compute the final_extra_cost
1565+
// It seems like ~5% of all states are final, so the following atomic may be fine
1566+
// if necessary, we could first reduce locally at the CTA level
1567+
const CostType final_cost =
1568+
cst_dev_params.d_fst_final_costs[token_state];
1569+
if(final_cost != cst_dev_params.fst_zero) {
1570+
IntegerCostType token_int_cost_with_final = floatToOrderedInt(orderedIntToFloat(token_int_cost) + final_cost);
1571+
IntegerCostType int_relative_cost = token_int_cost_with_final; // - 0.0f, the min_cost was reset to 0.0f
1572+
atomicMin(&lane_counters->int_relative_cost, int_relative_cost);
1573+
}
15641574
}
15651575
// If that FST state has only one token associated to it, we store
15661576
// that token directly in
@@ -2023,20 +2033,18 @@ void FinalizeProcessNonEmittingKernel(const dim3 &grid, const dim3 &block,
20232033
void GetBestCostStep1Kernel(const dim3 &grid, const dim3 &block,
20242034
const cudaStream_t &st,
20252035
const DeviceParams &cst_dev_params,
2026-
const KernelParams &kernel_params, bool isfinal,
2027-
CostType fst_zero) {
2036+
const KernelParams &kernel_params, bool isfinal) {
20282037
get_best_cost_step1_kernel<<<grid, block, 0, st>>>(
2029-
cst_dev_params, kernel_params, isfinal, fst_zero);
2038+
cst_dev_params, kernel_params, isfinal);
20302039
KALDI_DECODER_CUDA_CHECK_ERROR();
20312040
}
20322041

20332042
void GetBestCostStep2Kernel(const dim3 &grid, const dim3 &block,
20342043
const cudaStream_t &st,
20352044
const DeviceParams &cst_dev_params,
2036-
const KernelParams &kernel_params, bool isfinal,
2037-
CostType fst_zero) {
2045+
const KernelParams &kernel_params, bool isfinal) {
20382046
get_best_cost_step2_kernel<<<grid, block, 0, st>>>(
2039-
cst_dev_params, kernel_params, isfinal, fst_zero);
2047+
cst_dev_params, kernel_params, isfinal);
20402048
KALDI_DECODER_CUDA_CHECK_ERROR();
20412049
}
20422050

0 commit comments

Comments
 (0)