Commit

Merge pull request #32 from SeanNaren/label_params
Expose label selection parameters
ryanleary authored Oct 30, 2017
2 parents d2e728c + e639c53 commit 60dbbaf
Showing 4 changed files with 19 additions and 0 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -76,6 +76,16 @@ and returns:
- `score` is a FloatTensor of log-probabilities representing the likelihood of the transcription with shape `(top_paths, batch_size)`
- `out_seq_len` is an IntTensor containing the length of the output sequence with shape `(top_paths, batch_size)`

The `CTCBeamDecoder` may be further configured with label selection parameters: the label selection size (`label_size`) and the label selection margin (`label_margin`). These parameters help reduce
computation time.

Label selection size controls how many items in each beam are passed through to the beam scorer. Only the N items with the highest input scores are considered.
Label selection margin controls how far an item's input score may fall below that of the best-scoring label for the item to still be passed to the beam scorer. This margin is expressed in terms of log-probability. The default (`label_size=0`, `label_margin=-1`) is to do no label selection.

```python
decoder.set_label_selection_parameters(label_size=0, label_margin=6)
```
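
The effect of the two parameters on a single timestep can be sketched in plain Python. This is only an illustration of the behaviour described above, not the decoder's actual implementation: `select_labels` is a hypothetical helper, and the way the two filters are combined (size first, then margin) is an assumption.

```python
import math


def select_labels(log_probs, label_size=0, label_margin=-1.0):
    """Return the label indices that survive label selection for one timestep.

    Assumed semantics, taken from the README text rather than the C++ source:
    - label_size > 0 keeps only the label_size highest-scoring labels;
    - label_margin >= 0 keeps only labels whose log-probability is within
      label_margin of the best-scoring label;
    - the defaults (0 and -1) disable both filters.
    """
    # Rank label indices from highest to lowest log-probability.
    ranked = sorted(range(len(log_probs)), key=lambda i: log_probs[i], reverse=True)
    if label_size > 0:
        ranked = ranked[:label_size]
    if label_margin >= 0 and ranked:
        best = log_probs[ranked[0]]
        ranked = [i for i in ranked if best - log_probs[i] <= label_margin]
    return ranked


log_probs = [math.log(p) for p in (0.6, 0.3, 0.09, 0.01)]
print(select_labels(log_probs))                    # no selection: all four labels
print(select_labels(log_probs, label_size=2))      # two best labels only
print(select_labels(log_probs, label_margin=1.0))  # labels within 1.0 of the best
```

With a tight margin, low-probability labels are dropped before they reach the beam scorer, which is where the reduction in computation time comes from.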

### Utilities
```python
generate_lm_dict(dictionary_path, kenlm_path, output_path, labels, blank_index, space_index)
3 changes: 3 additions & 0 deletions pytorch_ctc/__init__.py
@@ -108,6 +108,9 @@ def __init__(self, scorer, labels, top_paths=1, beam_width=10, blank_index=0, sp
        self._decoder = ctc._get_ctc_beam_decoder(self._num_classes, top_paths, beam_width, blank_index,
                                                  self._scorer.get_scorer(), self._decoder_type)

    def set_label_selection_parameters(self, label_size=0, label_margin=-1):
        ctc._set_label_selection_parameters(self._decoder, label_size, label_margin)


def generate_lm_dict(dictionary_path, kenlm_path, output_path, labels, blank_index=0, space_index=28):
    if ctc._kenlm_enabled() != 1:
5 changes: 5 additions & 0 deletions pytorch_ctc/src/cpu_binding.cpp
@@ -115,6 +115,11 @@ namespace pytorch {
#endif
}

void set_label_selection_parameters(void *decoder, int label_selection_size, float label_selection_margin) {
  ctc::CTCBeamSearchDecoder<> *beam_decoder = static_cast<ctc::CTCBeamSearchDecoder<> *>(decoder);
  beam_decoder->SetLabelSelectionParameters(label_selection_size, label_selection_margin);
}

void* get_base_scorer() {
  ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer *beam_scorer = new ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer();
  return static_cast<void *>(beam_scorer);
1 change: 1 addition & 0 deletions pytorch_ctc/src/cpu_binding.h
@@ -13,6 +13,7 @@ void free_kenlm_scorer(void* kenlm_scorer);
void set_kenlm_scorer_lm_weight(void *scorer, float weight);
void set_kenlm_scorer_wc_weight(void *scorer, float weight);
void set_kenlm_scorer_vwc_weight(void *scorer, float weight);
void set_label_selection_parameters(void *decoder, int label_selection_size, float label_selection_margin);
void* get_base_scorer();

