Skip to content

Commit

Permalink
Merge pull request #37 from ryanleary/debug
Browse files Browse the repository at this point in the history
Fix KenLM-based decoder, Add Dictionary decoder
  • Loading branch information
ryanleary authored Nov 7, 2017
2 parents 60dbbaf + b1d20f9 commit dd0835e
Show file tree
Hide file tree
Showing 19 changed files with 3,041 additions and 343 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,5 @@ ENV/
*.exe
*.out
*.app

pytorch_ctc/ctcdecode.xcodeproj/
32 changes: 20 additions & 12 deletions pytorch_ctc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ def decode(self, probs, seq_len=None):
scores = torch.FloatTensor(self._top_paths, batch_size)
out_seq_len = torch.IntTensor(self._top_paths, batch_size)
alignments = torch.IntTensor(self._top_paths, batch_size, max_seq_len)
char_probs = torch.FloatTensor(self._top_paths, batch_size, max_seq_len)

result = ctc._ctc_beam_decode(self._decoder, self._decoder_type, probs, seq_len, output, scores, out_seq_len,
alignments)
alignments, char_probs)

return output, scores, out_seq_len, alignments
return output, scores, out_seq_len, alignments, char_probs


class BaseScorer(object):
Expand All @@ -72,12 +73,19 @@ def __init__(self):
self._scorer = ctc._get_base_scorer()


class DictScorer(BaseScorer):
    """Beam scorer backed by a dictionary trie only (no language model)."""

    def __init__(self, labels, trie_path, blank_index=0, space_index=28):
        super(DictScorer, self).__init__()
        # Scorer type 1 tags this as the dictionary scorer for the C binding.
        self._scorer_type = 1
        encoded_trie_path = trie_path.encode()
        self._scorer = ctc._get_dict_scorer(labels, len(labels), space_index,
                                            blank_index, encoded_trie_path)


class KenLMScorer(BaseScorer):
def __init__(self, labels, lm_path, trie_path, blank_index=0, space_index=28):
super(KenLMScorer, self).__init__()
if ctc._kenlm_enabled() != 1:
raise ImportError("pytorch-ctc not compiled with KenLM support.")
self._scorer_type = 1
self._scorer_type = 2
self._scorer = ctc._get_kenlm_scorer(labels, len(labels), space_index, blank_index, lm_path.encode(),
trie_path.encode())

Expand All @@ -94,10 +102,6 @@ def set_word_weight(self, weight):
if weight is not None:
ctc._set_kenlm_scorer_wc_weight(self._scorer, weight)

def set_valid_word_weight(self, weight):
if weight is not None:
ctc._set_kenlm_scorer_vwc_weight(self._scorer, weight)


class CTCBeamDecoder(BaseCTCBeamDecoder):
def __init__(self, scorer, labels, top_paths=1, beam_width=10, blank_index=0, space_index=28):
Expand All @@ -112,11 +116,15 @@ def set_label_selection_parameters(self, label_size=0, label_margin=-1):
ctc._set_label_selection_parameters(self._decoder, label_size, label_margin)


def generate_lm_dict(dictionary_path, output_path, labels, kenlm_path=None, blank_index=0, space_index=28):
    """Generate a trie dictionary for beam-search decoding.

    Args:
        dictionary_path: path to a plain-text vocabulary file.
        output_path: path the generated trie is written to.
        labels: the label alphabet (sequence of characters).
        kenlm_path: optional path to a KenLM language model; when given, the
            trie is built using the LM (requires a KenLM-enabled build).
        blank_index: index of the CTC blank label.
        space_index: index of the space label.

    Raises:
        ImportError: if kenlm_path is given but pytorch-ctc was compiled
            without KenLM support.
        ValueError: if the native dictionary generation reports failure.
    """
    if kenlm_path is not None and ctc._kenlm_enabled() != 1:
        raise ImportError("pytorch-ctc not compiled with KenLM support.")

    # Pick the LM-aware generator only when a KenLM model is supplied.
    if kenlm_path is not None:
        result = ctc._generate_lm_dict(labels, len(labels), blank_index, space_index, kenlm_path.encode(),
                                       dictionary_path.encode(), output_path.encode())
    else:
        result = ctc._generate_dict(labels, len(labels), blank_index, space_index,
                                    dictionary_path.encode(), output_path.encode())
    if result != 0:
        raise ValueError("Error encountered generating dictionary")
103 changes: 85 additions & 18 deletions pytorch_ctc/src/cpu_binding.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#ifdef TORCH_BINDING
#include <iostream>
#include "ctc_beam_entry.h"
#include "ctc_beam_scorer.h"
Expand All @@ -7,6 +8,7 @@
#include "util/status.h"
#include "TH.h"
#include "cpu_binding.h"
#include "ctc_beam_scorer_dict.h"

#ifdef INCLUDE_KENLM
#include "ctc_beam_scorer_klm.h"
Expand All @@ -17,6 +19,8 @@
namespace pytorch {
using pytorch::ctc::Labels;
using pytorch::ctc::Status;
using pytorch::ctc::DictBeamScorer;
using pytorch::ctc::ctc_beam_search::DictBeamState;

#ifdef INCLUDE_KENLM
using pytorch::ctc::KenLMBeamScorer;
Expand All @@ -40,7 +44,7 @@ namespace pytorch {
return full_score_return.prob;
}

int generate_dictionary(Labels& labels, const char* kenlm_path, const char* vocab_path, const char* trie_path) {
int generate_klm_dict_trie(Labels& labels, const char* kenlm_path, const char* vocab_path, const char* trie_path) {
lm::ngram::Config config;
config.load_method = util::POPULATE_OR_READ;
Model* model = lm::ngram::LoadVirtual(kenlm_path, config);
Expand Down Expand Up @@ -76,6 +80,35 @@ namespace pytorch {
}
#endif

int generate_dict_trie(Labels& labels, const char* vocab_path, const char* trie_path) {
  // Builds a character trie from a plain-text vocabulary (whitespace-separated
  // words) and serializes it to trie_path.
  // Returns 0 on success, -1 if either file cannot be opened.
  std::ifstream ifs;
  ifs.open(vocab_path, std::ifstream::in);

  ctc::TrieNode root(labels.GetSize());

  if (!ifs.is_open()) {
    // Report failures on stderr, not stdout.
    std::cerr << "unable to open vocabulary" << std::endl;
    return -1;
  }

  std::ofstream ofs;
  ofs.open(trie_path);
  if (!ofs.is_open()) {
    std::cerr << "unable to open trie output path" << std::endl;
    return -1;
  }

  std::string word;
  int i = 0;
  while (ifs >> word) {
    // Decode each UTF-8 word to wide characters so labels are looked up per
    // code point.  NOTE(review): utf8to16 into std::wstring assumes 16-bit
    // wchar_t (Windows); on Linux wchar_t is 32 bits — confirm intent.
    std::wstring wide_word;
    utf8::utf8to16(word.begin(), word.end(), std::back_inserter(wide_word));
    root.Insert(wide_word.c_str(), [&labels](wchar_t c) {
      return labels.GetLabel(c);
    }, i++, 0);
  }
  root.WriteToStream(ofs);
  ifs.close();
  ofs.close();
  return 0;
}

extern "C"
{
void* get_kenlm_scorer(const wchar_t* label_str, int labels_size, int space_index, int blank_index,
Expand All @@ -90,8 +123,18 @@ namespace pytorch {
}

void free_kenlm_scorer(void* kenlm_scorer) {
  // Reclaims a scorer allocated by get_kenlm_scorer.  Compiles to a no-op
  // when KenLM support is absent (no such scorer can exist in that build).
#ifdef INCLUDE_KENLM
  delete static_cast<ctc::KenLMBeamScorer*>(kenlm_scorer);
#endif
}

void* get_dict_scorer(const wchar_t* label_str, int labels_size, int space_index, int blank_index,
const char* trie_path) {
Labels* labels = new Labels(label_str, labels_size, blank_index, space_index);
ctc::DictBeamScorer *beam_scorer = new ctc::DictBeamScorer(labels, trie_path);
return static_cast<void*>(beam_scorer);
}

void set_kenlm_scorer_lm_weight(void *scorer, float weight) {
Expand All @@ -108,13 +151,6 @@ namespace pytorch {
#endif
}

// Sets the valid-word-count weight on a KenLM beam scorer.
// No-op when the library is compiled without KenLM support.
void set_kenlm_scorer_vwc_weight(void *scorer, float weight) {
#ifdef INCLUDE_KENLM
ctc::KenLMBeamScorer *beam_scorer = static_cast<ctc::KenLMBeamScorer *>(scorer);
beam_scorer->SetValidWordCountWeight(weight);
#endif
}

void set_label_selection_parameters(void *decoder, int label_selection_size, float label_selection_margin) {
ctc::CTCBeamSearchDecoder<> *beam_decoder = static_cast<ctc::CTCBeamSearchDecoder<> *>(decoder);
beam_decoder->SetLabelSelectionParameters(label_selection_size, label_selection_margin);
Expand All @@ -128,12 +164,19 @@ namespace pytorch {
void* get_ctc_beam_decoder(int num_classes, int top_paths, int beam_width, int blank_index, void *scorer, DecodeType type) {
switch (type) {
case CTC:
{
ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer *beam_scorer = static_cast<ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer *>(scorer);
ctc::CTCBeamSearchDecoder<> *decoder = new ctc::CTCBeamSearchDecoder<>
(num_classes, beam_width, beam_scorer, blank_index);
return static_cast<void *>(decoder);
}
{
ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer *beam_scorer = static_cast<ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer *>(scorer);
ctc::CTCBeamSearchDecoder<> *decoder = new ctc::CTCBeamSearchDecoder<>
(num_classes, beam_width, beam_scorer, blank_index);
return static_cast<void *>(decoder);
}
case CTC_DICT:
{
ctc::DictBeamScorer *beam_scorer = static_cast<ctc::DictBeamScorer *>(scorer);
ctc::CTCBeamSearchDecoder<DictBeamState> *decoder = new ctc::CTCBeamSearchDecoder<DictBeamState>
(num_classes, beam_width, beam_scorer, blank_index);
return static_cast<void *>(decoder);
}
#ifdef INCLUDE_KENLM
case CTC_KENLM:
{
Expand All @@ -148,7 +191,7 @@ namespace pytorch {
}

int ctc_beam_decode(void *void_decoder, DecodeType type, THFloatTensor *th_probs, THIntTensor *th_seq_len, THIntTensor *th_output,
THFloatTensor *th_scores, THIntTensor *th_out_len, THIntTensor *th_alignments)
THFloatTensor *th_scores, THIntTensor *th_out_len, THIntTensor *th_alignments, THFloatTensor *th_char_probs)
{
const int64_t max_time = THFloatTensor_size(th_probs, 0);
const int64_t batch_size = THFloatTensor_size(th_probs, 1);
Expand Down Expand Up @@ -182,6 +225,10 @@ namespace pytorch {
for (ctc::CTCDecoder::Output& alignment : alignments) {
alignment.resize(batch_size);
}
std::vector<ctc::CTCDecoder::CharProbability> char_probs(top_paths);
for (ctc::CTCDecoder::CharProbability& char_ : char_probs) {
char_.resize(batch_size);
}
float score[batch_size][top_paths];
memset(score, 0.0, batch_size*top_paths*sizeof(int));
Eigen::Map<Eigen::MatrixXf> *scores;
Expand All @@ -192,7 +239,17 @@ namespace pytorch {
{
ctc::CTCBeamSearchDecoder<> *decoder = static_cast<ctc::CTCBeamSearchDecoder<> *>(void_decoder);
scores = new Eigen::Map<Eigen::MatrixXf>(&score[0][0], batch_size, decoder->GetBeamWidth());
Status stat = decoder->Decode(seq_len, inputs, &outputs, scores, &alignments);
Status stat = decoder->Decode(seq_len, inputs, &outputs, scores, &alignments, &char_probs);
if (!stat.ok()) {
return 0;
}
}
break;
case CTC_DICT:
{
ctc::CTCBeamSearchDecoder<DictBeamState> *decoder = static_cast<ctc::CTCBeamSearchDecoder<DictBeamState> *>(void_decoder);
scores = new Eigen::Map<Eigen::MatrixXf>(&score[0][0], batch_size, decoder->GetBeamWidth());
Status stat = decoder->Decode(seq_len, inputs, &outputs, scores, &alignments, &char_probs);
if (!stat.ok()) {
return 0;
}
Expand All @@ -203,7 +260,7 @@ namespace pytorch {
{
ctc::CTCBeamSearchDecoder<KenLMBeamState> *decoder = static_cast<ctc::CTCBeamSearchDecoder<KenLMBeamState> *>(void_decoder);
scores = new Eigen::Map<Eigen::MatrixXf>(&score[0][0], batch_size, decoder->GetBeamWidth());
Status stat = decoder->Decode(seq_len, inputs, &outputs, scores, &alignments);
Status stat = decoder->Decode(seq_len, inputs, &outputs, scores, &alignments, &char_probs);
if (!stat.ok()) {
return 0;
}
Expand All @@ -219,6 +276,7 @@ namespace pytorch {
for (int b=0; b < batch_size; ++b) {
auto& p_batch = outputs[p][b];
auto& alignment_batch = alignments[p][b];
auto& char_prob_batch = char_probs[p][b];
int64_t num_decoded = p_batch.size();

max_decoded = std::max(max_decoded, num_decoded);
Expand All @@ -227,6 +285,7 @@ namespace pytorch {
// TODO: this could be more efficient (significant pointer arithmetic every time currently)
THIntTensor_set3d(th_output, p, b, t, p_batch[t]);
THIntTensor_set3d(th_alignments, p, b, t, alignment_batch[t]);
THFloatTensor_set3d(th_char_probs, p, b, t, char_prob_batch[t]);
THFloatTensor_set2d(th_scores, p, b, (*scores)(b, p));
}
}
Expand All @@ -239,12 +298,19 @@ namespace pytorch {
const char* lm_path, const char* dictionary_path, const char* output_path) {
#ifdef INCLUDE_KENLM
Labels labels(label_str, size, blank_index, space_index);
return generate_dictionary(labels, lm_path, dictionary_path, output_path);
return generate_klm_dict_trie(labels, lm_path, dictionary_path, output_path);
#else
return -1;
#endif
}

int generate_dict(const wchar_t* label_str, int size, int blank_index, int space_index,
                  const char* dictionary_path, const char* output_path) {
  // Thin C wrapper: builds the label alphabet on the stack and delegates the
  // trie construction.  Returns 0 on success, non-zero on failure.
  Labels labels(label_str, size, blank_index, space_index);
  return generate_dict_trie(labels, dictionary_path, output_path);
}


int kenlm_enabled() {
#ifdef INCLUDE_KENLM
return 1;
Expand All @@ -254,3 +320,4 @@ namespace pytorch {
}
}
}
#endif
15 changes: 11 additions & 4 deletions pytorch_ctc/src/cpu_binding.h
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
typedef enum {
CTC,
CTC_DICT,
CTC_KENLM
} DecodeType ;


/* scorers */
int kenlm_enabled();

void* get_dict_scorer(const wchar_t* label_str, int labels_size, int space_index, int blank_index,
const char* trie_path);
void* get_kenlm_scorer(const wchar_t* label_str, int labels_size, int space_index, int blank_index,
const char* lm_path, const char* trie_path);
void free_kenlm_scorer(void* kenlm_scorer);

void set_kenlm_scorer_lm_weight(void *scorer, float weight);
void set_kenlm_scorer_wc_weight(void *scorer, float weight);
void set_kenlm_scorer_vwc_weight(void *scorer, float weight);
// void set_kenlm_scorer_vwc_weight(void *scorer, float weight);
void set_label_selection_parameters(void *decoder, int label_selection_size, float label_selection_margin);
void* get_base_scorer();

Expand All @@ -23,11 +27,14 @@ void* get_ctc_beam_decoder(int num_classes, int top_paths, int beam_width, int b


/* run decoding */
int ctc_beam_decode(void *decoder, DecodeType type,
THFloatTensor *probs, THIntTensor *seq_len, THIntTensor *output,
THFloatTensor *scores, THIntTensor *th_out_len, THIntTensor *th_alignments);
int ctc_beam_decode(void *decoder, DecodeType type, THFloatTensor *probs,
THIntTensor *seq_len, THIntTensor *output, THFloatTensor *scores,
THIntTensor *th_out_len, THIntTensor *th_alignments, THFloatTensor *char_probs);


/* utilities */
int generate_lm_dict(const wchar_t* labels, int size, int blank_index, int space_index,
const char* lm_path, const char* dictionary_path, const char* output_path);

int generate_dict(const wchar_t* labels, int size, int blank_index, int space_index,
const char* dictionary_path, const char* output_path);
15 changes: 11 additions & 4 deletions pytorch_ctc/src/ctc_beam_entry.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,9 @@ struct BeamEntry {
}
std::vector<int> LabelSeq() const {
std::vector<int> labels;
int prev_label = -1;
const BeamEntry* c = this;
while (c->parent != nullptr) { // Checking c->parent to skip root leaf.
labels.push_back(c->label);
prev_label = c->label;
c = c->parent;
}
std::reverse(labels.begin(), labels.end());
Expand All @@ -105,17 +103,26 @@ struct BeamEntry {

std::vector<int> TimeStepSeq() const {
  // Walks parent links from this entry up to (but excluding) the root and
  // returns the time step recorded at each entry, in sequence order.
  // Fix: removed the local `prev_label`, which was written but never read.
  std::vector<int> time_steps;
  const BeamEntry *c = this;
  while (c->parent != nullptr) { // Checking c->parent to skip root leaf.
    time_steps.push_back(c->time_step);
    c = c->parent;
  }
  std::reverse(time_steps.begin(), time_steps.end());
  return time_steps;
}

std::vector<float> CharProbSeq() const {
std::vector<float> probs;
const BeamEntry *c = this;
while (c->parent != nullptr) { // Checking c->parent to skip root leaf.
probs.push_back(c->newp.total);
c = c->parent;
}
std::reverse(probs.begin(), probs.end());
return probs;
}

BeamEntry<CTCBeamState>* parent;
int label;
int time_step;
Expand Down
Loading

0 comments on commit dd0835e

Please sign in to comment.