Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix KenLM-based decoder, Add Dictionary decoder #37

Merged
merged 17 commits into from
Nov 7, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,5 @@ ENV/
*.exe
*.out
*.app

pytorch_ctc/ctcdecode.xcodeproj/
32 changes: 20 additions & 12 deletions pytorch_ctc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ def decode(self, probs, seq_len=None):
scores = torch.FloatTensor(self._top_paths, batch_size)
out_seq_len = torch.IntTensor(self._top_paths, batch_size)
alignments = torch.IntTensor(self._top_paths, batch_size, max_seq_len)
char_probs = torch.FloatTensor(self._top_paths, batch_size, max_seq_len)

result = ctc._ctc_beam_decode(self._decoder, self._decoder_type, probs, seq_len, output, scores, out_seq_len,
alignments)
alignments, char_probs)

return output, scores, out_seq_len, alignments
return output, scores, out_seq_len, alignments, char_probs


class BaseScorer(object):
Expand All @@ -72,12 +73,19 @@ def __init__(self):
self._scorer = ctc._get_base_scorer()


class DictScorer(BaseScorer):
    """Scorer that constrains beam search to words found in a dictionary trie.

    :param labels: label alphabet handed to the native binding.
    :param trie_path: path to a serialized trie built by ``generate_lm_dict``.
    :param blank_index: index of the CTC blank symbol.
    :param space_index: index of the word-separator symbol.
    """

    def __init__(self, labels, trie_path, blank_index=0, space_index=28):
        super(DictScorer, self).__init__()
        # Scorer type 1 identifies the dictionary-only scorer to the
        # native decoding layer.
        self._scorer_type = 1
        self._scorer = ctc._get_dict_scorer(labels, len(labels),
                                            space_index, blank_index,
                                            trie_path.encode())


class KenLMScorer(BaseScorer):
    """Scorer backed by a KenLM language model plus a dictionary trie.

    :param labels: label alphabet handed to the native binding.
    :param lm_path: path to the KenLM binary/arpa model file.
    :param trie_path: path to a serialized trie built by ``generate_lm_dict``.
    :param blank_index: index of the CTC blank symbol.
    :param space_index: index of the word-separator symbol.
    :raises ImportError: if the extension was built without KenLM support.
    """

    def __init__(self, labels, lm_path, trie_path, blank_index=0, space_index=28):
        super(KenLMScorer, self).__init__()
        if ctc._kenlm_enabled() != 1:
            raise ImportError("pytorch-ctc not compiled with KenLM support.")
        # Scorer type 2 identifies the KenLM scorer to the native decoding
        # layer.  (The previous duplicate assignment of 1 was dead code —
        # immediately overwritten — and has been removed.)
        self._scorer_type = 2
        self._scorer = ctc._get_kenlm_scorer(labels, len(labels), space_index,
                                             blank_index, lm_path.encode(),
                                             trie_path.encode())

Expand All @@ -94,10 +102,6 @@ def set_word_weight(self, weight):
if weight is not None:
ctc._set_kenlm_scorer_wc_weight(self._scorer, weight)

def set_valid_word_weight(self, weight):
    # Forward the valid-word-count weight to the native KenLM scorer.
    # A weight of None means "leave the scorer's current value unchanged".
    if weight is not None:
        ctc._set_kenlm_scorer_vwc_weight(self._scorer, weight)


class CTCBeamDecoder(BaseCTCBeamDecoder):
def __init__(self, scorer, labels, top_paths=1, beam_width=10, blank_index=0, space_index=28):
Expand All @@ -112,11 +116,15 @@ def set_label_selection_parameters(self, label_size=0, label_margin=-1):
ctc._set_label_selection_parameters(self._decoder, label_size, label_margin)


def generate_lm_dict(dictionary_path, output_path, labels, kenlm_path=None, blank_index=0, space_index=28):
    """Build a dictionary trie for decoding and write it to ``output_path``.

    :param dictionary_path: path to a plain-text vocabulary file.
    :param output_path: destination path for the serialized trie.
    :param labels: label alphabet used to map characters to label indices.
    :param kenlm_path: optional path to a KenLM model; when given, the trie
        is built with language-model scores via the KenLM-enabled binding.
    :param blank_index: index of the CTC blank symbol.
    :param space_index: index of the word-separator symbol.
    :raises ImportError: if a KenLM path is given but the extension was
        built without KenLM support.
    :raises ValueError: if the native trie generation reports failure.
    """
    if kenlm_path is not None and ctc._kenlm_enabled() != 1:
        raise ImportError("pytorch-ctc not compiled with KenLM support.")

    # The stale duplicate call to ctc._generate_lm_dict left over from the
    # old single-path implementation was dead code (its result was always
    # overwritten below) and has been removed.
    if kenlm_path is not None:
        result = ctc._generate_lm_dict(labels, len(labels), blank_index, space_index,
                                       kenlm_path.encode(), dictionary_path.encode(),
                                       output_path.encode())
    else:
        result = ctc._generate_dict(labels, len(labels), blank_index, space_index,
                                    dictionary_path.encode(), output_path.encode())
    if result != 0:
        raise ValueError("Error encountered generating dictionary")
103 changes: 85 additions & 18 deletions pytorch_ctc/src/cpu_binding.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#ifdef TORCH_BINDING
#include <iostream>
#include "ctc_beam_entry.h"
#include "ctc_beam_scorer.h"
Expand All @@ -7,6 +8,7 @@
#include "util/status.h"
#include "TH.h"
#include "cpu_binding.h"
#include "ctc_beam_scorer_dict.h"

#ifdef INCLUDE_KENLM
#include "ctc_beam_scorer_klm.h"
Expand All @@ -17,6 +19,8 @@
namespace pytorch {
using pytorch::ctc::Labels;
using pytorch::ctc::Status;
using pytorch::ctc::DictBeamScorer;
using pytorch::ctc::ctc_beam_search::DictBeamState;

#ifdef INCLUDE_KENLM
using pytorch::ctc::KenLMBeamScorer;
Expand All @@ -40,7 +44,7 @@ namespace pytorch {
return full_score_return.prob;
}

int generate_dictionary(Labels& labels, const char* kenlm_path, const char* vocab_path, const char* trie_path) {
int generate_klm_dict_trie(Labels& labels, const char* kenlm_path, const char* vocab_path, const char* trie_path) {
lm::ngram::Config config;
config.load_method = util::POPULATE_OR_READ;
Model* model = lm::ngram::LoadVirtual(kenlm_path, config);
Expand Down Expand Up @@ -76,6 +80,35 @@ namespace pytorch {
}
#endif

// Build a dictionary trie from a plain-text vocabulary file (one
// whitespace-separated word per token) and serialize it to trie_path.
// Returns 0 on success, -1 if the vocabulary file cannot be opened.
// NOTE(review): errors go to std::cout rather than std::cerr — consider
// switching; also utf8to16 into std::wstring assumes wchar_t holds UTF-16
// code units (true on Windows, not on Linux where wchar_t is 32-bit) —
// confirm this matches Labels::GetLabel's expectations.
int generate_dict_trie(Labels& labels, const char* vocab_path, const char* trie_path) {
  std::ifstream ifs;
  ifs.open(vocab_path, std::ifstream::in);

  // Root of the trie sized to the full label alphabet.
  ctc::TrieNode root(labels.GetSize());

  if (!ifs.is_open()) {
    std::cout << "unable to open vocabulary" << std::endl;
    return -1;
  }

  std::ofstream ofs;
  ofs.open(trie_path);

  std::string word;
  int i = 0;  // sequential word id assigned in file order
  while (ifs >> word) {
    // Widen each UTF-8 word so characters can be mapped through Labels.
    std::wstring wide_word;
    utf8::utf8to16(word.begin(), word.end(), std::back_inserter(wide_word));
    root.Insert(wide_word.c_str(), [&labels](wchar_t c) {
      return labels.GetLabel(c);
    }, i++, 0);
  }
  root.WriteToStream(ofs);
  ifs.close();
  ofs.close();
  return 0;
}

extern "C"
{
void* get_kenlm_scorer(const wchar_t* label_str, int labels_size, int space_index, int blank_index,
Expand All @@ -90,8 +123,18 @@ namespace pytorch {
}

// Destroy a scorer previously returned by get_kenlm_scorer.
void free_kenlm_scorer(void* kenlm_scorer) {
#ifdef INCLUDE_KENLM
  // Cast back to the concrete type so the correct destructor runs;
  // deleting through void* would be undefined behavior.
  ctc::KenLMBeamScorer* beam_scorer = static_cast<ctc::KenLMBeamScorer*>(kenlm_scorer);
  delete beam_scorer;
#endif
  // NOTE(review): without INCLUDE_KENLM this silently does nothing — a
  // KenLM scorer should never exist in that configuration, but the no-op
  // would mask a leak if one did.
  return;
}

// Create a dictionary-constrained beam scorer for the given label alphabet
// and serialized trie. Returned as void* for the C binding surface.
// NOTE(review): `labels` is heap-allocated here and handed to the scorer;
// presumably DictBeamScorer takes ownership — confirm, else this leaks.
void* get_dict_scorer(const wchar_t* label_str, int labels_size, int space_index, int blank_index,
                      const char* trie_path) {
  Labels* labels = new Labels(label_str, labels_size, blank_index, space_index);
  ctc::DictBeamScorer *beam_scorer = new ctc::DictBeamScorer(labels, trie_path);
  return static_cast<void*>(beam_scorer);
}

void set_kenlm_scorer_lm_weight(void *scorer, float weight) {
Expand All @@ -108,13 +151,6 @@ namespace pytorch {
#endif
}

// Set the valid-word-count weight on a KenLM scorer.
// No-op when the extension is built without KenLM support.
void set_kenlm_scorer_vwc_weight(void *scorer, float weight) {
#ifdef INCLUDE_KENLM
  ctc::KenLMBeamScorer *beam_scorer = static_cast<ctc::KenLMBeamScorer *>(scorer);
  beam_scorer->SetValidWordCountWeight(weight);
#endif
}

void set_label_selection_parameters(void *decoder, int label_selection_size, float label_selection_margin) {
ctc::CTCBeamSearchDecoder<> *beam_decoder = static_cast<ctc::CTCBeamSearchDecoder<> *>(decoder);
beam_decoder->SetLabelSelectionParameters(label_selection_size, label_selection_margin);
Expand All @@ -128,12 +164,19 @@ namespace pytorch {
void* get_ctc_beam_decoder(int num_classes, int top_paths, int beam_width, int blank_index, void *scorer, DecodeType type) {
switch (type) {
case CTC:
{
ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer *beam_scorer = static_cast<ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer *>(scorer);
ctc::CTCBeamSearchDecoder<> *decoder = new ctc::CTCBeamSearchDecoder<>
(num_classes, beam_width, beam_scorer, blank_index);
return static_cast<void *>(decoder);
}
{
ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer *beam_scorer = static_cast<ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer *>(scorer);
ctc::CTCBeamSearchDecoder<> *decoder = new ctc::CTCBeamSearchDecoder<>
(num_classes, beam_width, beam_scorer, blank_index);
return static_cast<void *>(decoder);
}
case CTC_DICT:
{
ctc::DictBeamScorer *beam_scorer = static_cast<ctc::DictBeamScorer *>(scorer);
ctc::CTCBeamSearchDecoder<DictBeamState> *decoder = new ctc::CTCBeamSearchDecoder<DictBeamState>
(num_classes, beam_width, beam_scorer, blank_index);
return static_cast<void *>(decoder);
}
#ifdef INCLUDE_KENLM
case CTC_KENLM:
{
Expand All @@ -148,7 +191,7 @@ namespace pytorch {
}

int ctc_beam_decode(void *void_decoder, DecodeType type, THFloatTensor *th_probs, THIntTensor *th_seq_len, THIntTensor *th_output,
THFloatTensor *th_scores, THIntTensor *th_out_len, THIntTensor *th_alignments)
THFloatTensor *th_scores, THIntTensor *th_out_len, THIntTensor *th_alignments, THFloatTensor *th_char_probs)
{
const int64_t max_time = THFloatTensor_size(th_probs, 0);
const int64_t batch_size = THFloatTensor_size(th_probs, 1);
Expand Down Expand Up @@ -182,6 +225,10 @@ namespace pytorch {
for (ctc::CTCDecoder::Output& alignment : alignments) {
alignment.resize(batch_size);
}
std::vector<ctc::CTCDecoder::CharProbability> char_probs(top_paths);
for (ctc::CTCDecoder::CharProbability& char_ : char_probs) {
char_.resize(batch_size);
}
float score[batch_size][top_paths];
memset(score, 0.0, batch_size*top_paths*sizeof(int));
Eigen::Map<Eigen::MatrixXf> *scores;
Expand All @@ -192,7 +239,17 @@ namespace pytorch {
{
ctc::CTCBeamSearchDecoder<> *decoder = static_cast<ctc::CTCBeamSearchDecoder<> *>(void_decoder);
scores = new Eigen::Map<Eigen::MatrixXf>(&score[0][0], batch_size, decoder->GetBeamWidth());
Status stat = decoder->Decode(seq_len, inputs, &outputs, scores, &alignments);
Status stat = decoder->Decode(seq_len, inputs, &outputs, scores, &alignments, &char_probs);
if (!stat.ok()) {
return 0;
}
}
break;
case CTC_DICT:
{
ctc::CTCBeamSearchDecoder<DictBeamState> *decoder = static_cast<ctc::CTCBeamSearchDecoder<DictBeamState> *>(void_decoder);
scores = new Eigen::Map<Eigen::MatrixXf>(&score[0][0], batch_size, decoder->GetBeamWidth());
Status stat = decoder->Decode(seq_len, inputs, &outputs, scores, &alignments, &char_probs);
if (!stat.ok()) {
return 0;
}
Expand All @@ -203,7 +260,7 @@ namespace pytorch {
{
ctc::CTCBeamSearchDecoder<KenLMBeamState> *decoder = static_cast<ctc::CTCBeamSearchDecoder<KenLMBeamState> *>(void_decoder);
scores = new Eigen::Map<Eigen::MatrixXf>(&score[0][0], batch_size, decoder->GetBeamWidth());
Status stat = decoder->Decode(seq_len, inputs, &outputs, scores, &alignments);
Status stat = decoder->Decode(seq_len, inputs, &outputs, scores, &alignments, &char_probs);
if (!stat.ok()) {
return 0;
}
Expand All @@ -219,6 +276,7 @@ namespace pytorch {
for (int b=0; b < batch_size; ++b) {
auto& p_batch = outputs[p][b];
auto& alignment_batch = alignments[p][b];
auto& char_prob_batch = char_probs[p][b];
int64_t num_decoded = p_batch.size();

max_decoded = std::max(max_decoded, num_decoded);
Expand All @@ -227,6 +285,7 @@ namespace pytorch {
// TODO: this could be more efficient (significant pointer arithmetic every time currently)
THIntTensor_set3d(th_output, p, b, t, p_batch[t]);
THIntTensor_set3d(th_alignments, p, b, t, alignment_batch[t]);
THFloatTensor_set3d(th_char_probs, p, b, t, char_prob_batch[t]);
THFloatTensor_set2d(th_scores, p, b, (*scores)(b, p));
}
}
Expand All @@ -239,12 +298,19 @@ namespace pytorch {
const char* lm_path, const char* dictionary_path, const char* output_path) {
#ifdef INCLUDE_KENLM
Labels labels(label_str, size, blank_index, space_index);
return generate_dictionary(labels, lm_path, dictionary_path, output_path);
return generate_klm_dict_trie(labels, lm_path, dictionary_path, output_path);
#else
return -1;
#endif
}

int generate_dict(const wchar_t* label_str, int size, int blank_index, int space_index,
const char* dictionary_path, const char* output_path) {
Labels labels(label_str, size, blank_index, space_index);
return generate_dict_trie(labels, dictionary_path, output_path);
}


int kenlm_enabled() {
#ifdef INCLUDE_KENLM
return 1;
Expand All @@ -254,3 +320,4 @@ namespace pytorch {
}
}
}
#endif
15 changes: 11 additions & 4 deletions pytorch_ctc/src/cpu_binding.h
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
// Selects which decoder/scorer pairing the binding layer instantiates.
typedef enum {
  CTC,        // plain beam search with the default scorer
  CTC_DICT,   // beam search constrained by a dictionary trie
  CTC_KENLM   // beam search scored by a KenLM language model
} DecodeType ;


/* scorers */
int kenlm_enabled();

void* get_dict_scorer(const wchar_t* label_str, int labels_size, int space_index, int blank_index,
const char* trie_path);
void* get_kenlm_scorer(const wchar_t* label_str, int labels_size, int space_index, int blank_index,
const char* lm_path, const char* trie_path);
void free_kenlm_scorer(void* kenlm_scorer);

void set_kenlm_scorer_lm_weight(void *scorer, float weight);
void set_kenlm_scorer_wc_weight(void *scorer, float weight);
void set_kenlm_scorer_vwc_weight(void *scorer, float weight);
// void set_kenlm_scorer_vwc_weight(void *scorer, float weight);
void set_label_selection_parameters(void *decoder, int label_selection_size, float label_selection_margin);
void* get_base_scorer();

Expand All @@ -23,11 +27,14 @@ void* get_ctc_beam_decoder(int num_classes, int top_paths, int beam_width, int b


/* run decoding */
int ctc_beam_decode(void *decoder, DecodeType type,
THFloatTensor *probs, THIntTensor *seq_len, THIntTensor *output,
THFloatTensor *scores, THIntTensor *th_out_len, THIntTensor *th_alignments);
int ctc_beam_decode(void *decoder, DecodeType type, THFloatTensor *probs,
THIntTensor *seq_len, THIntTensor *output, THFloatTensor *scores,
THIntTensor *th_out_len, THIntTensor *th_alignments, THFloatTensor *char_probs);


/* utilities */
int generate_lm_dict(const wchar_t* labels, int size, int blank_index, int space_index,
const char* lm_path, const char* dictionary_path, const char* output_path);

int generate_dict(const wchar_t* labels, int size, int blank_index, int space_index,
const char* dictionary_path, const char* output_path);
15 changes: 11 additions & 4 deletions pytorch_ctc/src/ctc_beam_entry.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,9 @@ struct BeamEntry {
}
std::vector<int> LabelSeq() const {
std::vector<int> labels;
int prev_label = -1;
const BeamEntry* c = this;
while (c->parent != nullptr) { // Checking c->parent to skip root leaf.
labels.push_back(c->label);
prev_label = c->label;
c = c->parent;
}
std::reverse(labels.begin(), labels.end());
Expand All @@ -105,17 +103,26 @@ struct BeamEntry {

// Walk the beam chain from this entry back to the root and return the
// time step recorded for each emitted label, in chronological order.
// (The unused `prev_label` bookkeeping variable — written but never read —
// has been removed.)
std::vector<int> TimeStepSeq() const {
  std::vector<int> time_steps;
  const BeamEntry *c = this;
  while (c->parent != nullptr) { // Checking c->parent to skip root leaf.
    time_steps.push_back(c->time_step);
    c = c->parent;
  }
  std::reverse(time_steps.begin(), time_steps.end());
  return time_steps;
}

std::vector<float> CharProbSeq() const {
std::vector<float> probs;
const BeamEntry *c = this;
while (c->parent != nullptr) { // Checking c->parent to skip root leaf.
probs.push_back(c->newp.total);
c = c->parent;
}
std::reverse(probs.begin(), probs.end());
return probs;
}

BeamEntry<CTCBeamState>* parent;
int label;
int time_step;
Expand Down
Loading