Skip to content

Commit

Permalink
Rename to ctcdecode
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanleary committed Nov 7, 2017
1 parent dd0835e commit c02a771
Show file tree
Hide file tree
Showing 22 changed files with 92 additions and 83 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,4 @@ ENV/
*.out
*.app

pytorch_ctc/ctcdecode.xcodeproj/
ctcdecode/ctcdecode.xcodeproj/
43 changes: 24 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# pytorch-ctc
PyTorch-CTC is an implementation of CTC (Connectionist Temporal Classification) beam search decoding for PyTorch. C++ code borrowed liberally from TensorFlow with some improvements to increase flexibility.
# ctcdecode
ctcdecode is an implementation of CTC (Connectionist Temporal Classification) beam search decoding for PyTorch.
C++ code borrowed liberally from TensorFlow with some improvements to increase flexibility.
It includes swappable scorer support enabling standard beam search, dictionary-based decoding, and KenLM-based decoding.

## Installation
The library is largely self-contained and requires only PyTorch and CFFI. Building the C++ library requires gcc. KenLM language modeling support is also optionally included, and enabled by default.
The library is largely self-contained and requires only PyTorch and CFFI. Building the C++ library requires gcc or clang. KenLM language modeling support is also optionally included, and enabled by default.

```bash
# get the code
git clone --recursive https://github.com/ryanleary/pytorch-ctc.git
cd pytorch-ctc
git clone --recursive https://github.com/parlance/ctcdecode.git
cd ctcdecode

# install dependencies (PyTorch and CFFI)
pip install -r requirements.txt
Expand All @@ -18,16 +20,21 @@ python setup.py install
```

## API
pytorch-ctc includes a CTC beam search decoder with multiple scorer implementations. A `scorer` is a function that the decoder calls to condition the probability of a given beam based on its state.
ctcdecode includes a CTC beam search decoder with multiple scorer implementations. A `scorer` is a function that the decoder calls to condition the probability of a given beam based on its state.

### Scorers
Two Scorer implementations are currently implemented for pytorch-ctc.
Three Scorer implementations are currently implemented for ctcdecode.

**Scorer:** is a NO-OP and enables the decoder to do a vanilla beam decode
```python
scorer = Scorer()
```

**DictScorer:** conditions beams based on the provided dictionary trie. Only words in the dictionary will be hypothesized.
```python
scorer = DictScorer(labels, trie_path, blank_index=0, space_index=28):
```

**KenLMScorer:** conditions beams based on the provided KenLM binary language model.
```python
scorer = KenLMScorer(labels, lm_path, trie_path, blank_index=0, space_index=28)
Expand All @@ -36,22 +43,21 @@ scorer = KenLMScorer(labels, lm_path, trie_path, blank_index=0, space_index=28)
where:
- `labels` is a string of output labels given in the same order as the output layer
- `lm_path` path to a binary KenLM language model for decoding
- `trie_path` path to a Trie containing the lexicon (see generate_lm_trie)
- `trie_path` path to a Trie containing the lexicon (see generate_lm_dict)
- `blank_index` is used to specify which position in the output distribution represents the `blank` class
- `space_index` is used to specify which position in the output distribution represents the word separator class

The `KenLMScorer` may be further configured with weights for the language model contribution to the score (`lm_weight`), as well as word and valid word bonuses (to offset decreasing probability as a function of sequence length).
The `KenLMScorer` may be further configured with weights for the language model contribution to the score (`lm_weight`), as well as word bonuses (to offset decreasing probability as a function of sequence length).

```python
scorer.set_lm_weight(2.1)
scorer.set_word_weight(1.1)
scorer.set_valid_word_weight(1.5)
scorer.set_lm_weight(2.0)
scorer.set_word_weight(0.1)
```

### Decoder
```python
decoder = CTCBeamDecoder(scorer, labels, top_paths=3, beam_width=20,
blank_index=0, space_index=28, merge_repeated=False)
blank_index=0, space_index=28)
```

where:
Expand All @@ -61,10 +67,9 @@ where:
- `beam_width` is the number of beams to evaluate in a given step
- `blank_index` is used to specify which position in the output distribution represents the `blank` class
- `space_index` is used to specify which position in the output distribution represents the word separator class
- `merge_repeated` if True will collapse repeated characters

```python
output, score, out_seq_len = decoder.decode(probs, sizes=None)
output, score, out_seq_len, offsets = decoder.decode(probs, sizes=None)
```

where:
Expand All @@ -75,9 +80,9 @@ and returns:
- `output` is an IntTensor of character classes of shape `(top_paths, batch_size, seq_len)`
- `score` is a FloatTensor of log-probabilities representing the likelihood of the transcription with shape `(top_paths, batch_size)`
- `out_seq_len` is an IntTensor containing the length of the output sequence with shape `(top_paths, batch_size)`
- `offsets` is an IntTensor returning the index of the input at which the character occurs. Can be used for generating time alignments

The `CTCBeamDecoder` may be further configured with weights for the label size (`label_size`), and label margin ('label_margin'). These parameters helps to reduce
the computation time.
The `CTCBeamDecoder` may be further configured with weights for the label size (`label_size`), and label margin (`label_margin`). These parameters helps to reduce the computation time.

Label selection size controls how many items in each beam are passed through to the beam scorer. Only items with top N input scores are considered.
Label selection margin controls the difference between minimal input score (versus the best scoring label) for an item to be passed to the beam scorer. This margin is expressed in terms of log-probability. Default is to do no label selection.
Expand All @@ -88,10 +93,10 @@ decoder.set_label_selection_parameters(label_size=0, label_margin=6)

### Utilities
```python
generate_lm_dict(dictionary_path, kenlm_path, output_path, labels, blank_index, space_index)
generate_lm_dict(dictionary_path, output_path, labels, kenlm_path=None, blank_index=0, space_index=1)
```

A vocabulary trie is required for the KenLM Scorer. The trie is created from a lexicon specified as a newline separated text file of words in the vocabulary.
A vocabulary trie is required for the KenLM Scorer. The trie is created from a lexicon specified as a newline separated text file of words in the vocabulary. The DictScorer also requires this function be run to generate a dictionary trie. In this case, a `kenlm_path` is not required.

## Acknowledgements
Thanks to [ebrevdo](https://github.com/ebrevdo) for the original TensorFlow CTC decoder implementation, [timediv](https://github.com/timediv) for his KenLM extension, and [SeanNaren](https://github.com/seannaren) for his assistance.
61 changes: 31 additions & 30 deletions pytorch_ctc/__init__.py → ctcdecode/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import torch
import pytorch_ctc as ctc
import ctcdecode as ctc
from torch.utils.ffi import _wrap_function
from ._ctc_decode import lib as _lib, ffi as _ffi

__all__ = []


def _import_symbols(locals):
for symbol in dir(_lib):
fn = getattr(_lib, symbol)
new_symbol = "_" + symbol
locals[new_symbol] = _wrap_function(fn, _ffi)
__all__.append(new_symbol)


_import_symbols(locals())
from ._ext import ctc_decode
# from ._ext._ctc_decode import lib as _lib, ffi as _ffi
#
# __all__ = []
#
#
# def _import_symbols(locals):
# for symbol in dir(_lib):
# fn = getattr(_lib, symbol)
# new_symbol = "_" + symbol
# locals[new_symbol] = _wrap_function(fn, _ffi)
# __all__.append(new_symbol)
#
#
# _import_symbols(locals())


class BaseCTCBeamDecoder(object):
Expand Down Expand Up @@ -49,7 +50,7 @@ def decode(self, probs, seq_len=None):
alignments = torch.IntTensor(self._top_paths, batch_size, max_seq_len)
char_probs = torch.FloatTensor(self._top_paths, batch_size, max_seq_len)

result = ctc._ctc_beam_decode(self._decoder, self._decoder_type, probs, seq_len, output, scores, out_seq_len,
result = ctc_decode.ctc_beam_decode(self._decoder, self._decoder_type, probs, seq_len, output, scores, out_seq_len,
alignments, char_probs)

return output, scores, out_seq_len, alignments, char_probs
Expand All @@ -70,37 +71,37 @@ def get_scorer(self):
class Scorer(BaseScorer):
def __init__(self):
super(Scorer, self).__init__()
self._scorer = ctc._get_base_scorer()
self._scorer = ctc_decode.get_base_scorer()


class DictScorer(BaseScorer):
def __init__(self, labels, trie_path, blank_index=0, space_index=28):
super(DictScorer, self).__init__()
self._scorer_type = 1
self._scorer = ctc._get_dict_scorer(labels, len(labels), space_index, blank_index, trie_path.encode())
self._scorer = ctc_decode.get_dict_scorer(labels, len(labels), space_index, blank_index, trie_path.encode())


class KenLMScorer(BaseScorer):
def __init__(self, labels, lm_path, trie_path, blank_index=0, space_index=28):
super(KenLMScorer, self).__init__()
if ctc._kenlm_enabled() != 1:
if ctc_decode.kenlm_enabled() != 1:
raise ImportError("pytorch-ctc not compiled with KenLM support.")
self._scorer_type = 2
self._scorer = ctc._get_kenlm_scorer(labels, len(labels), space_index, blank_index, lm_path.encode(),
self._scorer = ctc_decode.get_kenlm_scorer(labels, len(labels), space_index, blank_index, lm_path.encode(),
trie_path.encode())

# This is a way to make sure the destructor is called for the C++ object
# Frees all the member data items that have allocated memory
def __del__(self):
ctc._free_kenlm_scorer(self._scorer)
ctc_decode.free_kenlm_scorer(self._scorer)

def set_lm_weight(self, weight):
if weight is not None:
ctc._set_kenlm_scorer_lm_weight(self._scorer, weight)
ctc_decode.set_kenlm_scorer_lm_weight(self._scorer, weight)

def set_word_weight(self, weight):
if weight is not None:
ctc._set_kenlm_scorer_wc_weight(self._scorer, weight)
ctc_decode.set_kenlm_scorer_wc_weight(self._scorer, weight)


class CTCBeamDecoder(BaseCTCBeamDecoder):
Expand All @@ -109,22 +110,22 @@ def __init__(self, scorer, labels, top_paths=1, beam_width=10, blank_index=0, sp
blank_index=blank_index, space_index=space_index)
self._scorer = scorer
self._decoder_type = self._scorer.get_scorer_type()
self._decoder = ctc._get_ctc_beam_decoder(self._num_classes, top_paths, beam_width, blank_index,
self._scorer.get_scorer(), self._decoder_type)
self._decoder = ctc_decode.get_ctc_beam_decoder(self._num_classes, top_paths, beam_width, blank_index,
self._scorer.get_scorer(), self._decoder_type)

def set_label_selection_parameters(self, label_size=0, label_margin=-1):
ctc._set_label_selection_parameters(self._decoder, label_size, label_margin)
ctc_decode.set_label_selection_parameters(self._decoder, label_size, label_margin)


def generate_lm_dict(dictionary_path, output_path, labels, kenlm_path=None, blank_index=0, space_index=28):
if kenlm_path is not None and ctc._kenlm_enabled() != 1:
raise ImportError("pytorch-ctc not compiled with KenLM support.")
result = None
if kenlm_path is not None:
result = ctc._generate_lm_dict(labels, len(labels), blank_index, space_index, kenlm_path.encode(),
dictionary_path.encode(), output_path.encode())
result = ctc_decode.generate_lm_dict(labels, len(labels), blank_index, space_index, kenlm_path.encode(),
dictionary_path.encode(), output_path.encode())
else:
result = ctc._generate_dict(labels, len(labels), blank_index, space_index,
dictionary_path.encode(), output_path.encode())
result = ctc_decode.generate_dict(labels, len(labels), blank_index, space_index,
dictionary_path.encode(), output_path.encode())
if result != 0:
raise ValueError("Error encountered generating dictionary")
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class CTCBeamSearchDecoder : public CTCDecoder {
void Reset();

// Extract the top n paths at current time step
Status TopPaths(int n,
Status TopPaths(unsigned long n,
std::vector<std::vector<int>>* paths,
std::vector<float>* beam_probs,
std::vector<std::vector<int>>* alignments,
Expand Down Expand Up @@ -156,13 +156,13 @@ Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
CTCDecoder::ScoreOutput* beam_probs,
std::vector<CTCDecoder::Output>* alignment,
std::vector<CTCDecoder::CharProbability>* char_probs) {
int batch_size_ = input[0].rows();
int batch_size_ = (int)input[0].rows();
// Storage for top paths.
std::vector<std::vector<int>> beams;
std::vector<float> beam_log_probabilities;
std::vector<std::vector<float>> char_log_probabilities;
std::vector<std::vector<int>> beam_alignments;
int top_n = output->size();
unsigned long top_n = output->size();

// check data structure shapes
if (std::any_of(output->begin(), output->end(),
Expand Down Expand Up @@ -273,6 +273,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(const Vector& raw
b->newp.label =
LogSumExp(b->newp.label,
beam_scorer_->GetStateExpansionScore(b->state, previous));
b->time_step = time_step;
}
// Plabel(l=abc @ t=6) *= P(c @ 6)
b->newp.label += input(b->label);
Expand All @@ -284,7 +285,6 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(const Vector& raw

// Push the entry back to the top paths list.
// Note, this will always fill leaves back up in sorted order.
b->time_step = time_step;
leaves_.push(b);
}

Expand Down Expand Up @@ -386,7 +386,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
// generate the n-best character list.
template <typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::TopPaths(
int n,
unsigned long n,
std::vector<std::vector<int>>* paths,
std::vector<float>* beam_probs,
std::vector<std::vector<int>>* alignments,
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
27 changes: 15 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import platform
import glob

from distutils.core import setup, Extension
from setuptools import setup, find_packages
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

#Does gcc compile with this header and library?
def compile_test(header, library):
dummy_path = os.path.join(os.path.dirname(__file__), "dummy")
Expand Down Expand Up @@ -44,36 +46,37 @@ def compile_test(header, library):
lib_sources = glob.glob('third_party/kenlm/util/*.cc') + glob.glob('third_party/kenlm/lm/*.cc') + glob.glob('third_party/kenlm/util/double-conversion/*.cc')
lib_sources = [fn for fn in lib_sources if not (fn.endswith('main.cc') or fn.endswith('test.cc'))]

third_party_includes=["third_party/" + lib for lib in third_party_libs]
ctc_sources = ['pytorch_ctc/src/cpu_binding.cpp', 'pytorch_ctc/src/util/status.cpp']
ctc_headers = ['pytorch_ctc/src/cpu_binding.h',]
third_party_includes=[os.path.realpath(os.path.join("third_party", lib)) for lib in third_party_libs]
ctc_sources = ['ctcdecode/src/cpu_binding.cpp', 'ctcdecode/src/util/status.cpp']
ctc_headers = ['ctcdecode/src/cpu_binding.h',]

ffi = create_extension(
name='ctc_decode',
name='ctcdecode._ext.ctc_decode',
package=True,
language='c++',
headers=ctc_headers,
sources=ctc_sources + lib_sources,
include_dirs=third_party_includes,
with_cuda=False,
libraries=ext_libs,
extra_compile_args=compile_args#, '-DINCLUDE_KENLM']
extra_compile_args=compile_args,
relative_to=__file__,
with_cuda=False
)
ffi = ffi.distutils_extension()
ffi.name = 'pytorch_ctc._ctc_decode'

setup(
name="pytorch_ctc",
version="0.1",
name="ctcdecode",
version="0.2",
description="CTC Decoder for PyTorch based on TensorFlow's implementation",
url="https://github.com/ryanleary/pytorch-ctc-decode",
url="https://github.com/parlance/ctcdecode",
author="Ryan Leary",
author_email="ryanleary@gmail.com",
# Require cffi.
install_requires=["cffi>=1.0.0"],
setup_requires=["cffi>=1.0.0"],
# Exclude the build files.
packages=["pytorch_ctc"],
packages=find_packages(exclude=["build"]),
# Extensions to compile.
ext_modules=[ffi]

)
Loading

0 comments on commit c02a771

Please sign in to comment.