Skip to content

Commit

Permalink
Add log probs for all tokens (#1755)
Browse files Browse the repository at this point in the history
* add log probs

* fix compilation

* fix compilation

* fix test

* fix black

* return logits

* fix compilation

* fix test

* last clean
  • Loading branch information
minhthuc2502 authored Aug 26, 2024
1 parent 8ba828c commit 6647945
Show file tree
Hide file tree
Showing 17 changed files with 122 additions and 13 deletions.
8 changes: 7 additions & 1 deletion include/ctranslate2/decoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@ namespace ctranslate2 {
std::vector<std::vector<size_t>> hypotheses;
std::vector<float> scores;
std::vector<std::vector<std::vector<float>>> attention;
std::vector<std::vector<StorageView>> logits_vocab;
};

// Describes a single token produced at one decoding step.
// Optional members are std::optional and may hold no value; which option enables each one
// is not visible here — confirm against the decoder implementation.
struct DecodingStepResult {
// Decoding step at which the token was produced (index base not visible here — TODO confirm).
size_t step;
// Presumably the index of the batch entry the token belongs to — verify against caller.
size_t batch_id;
// ID of the generated token; GenerationStepResult passes it to Vocabulary::to_token().
size_t token_id;
// Index of the hypothesis in the batch.
size_t hypothesis_id;
// NOTE(review): this view lists both `log_prob` and `score`; `log_prob` looks like the
// pre-rename name of `score` — confirm that only one of them exists in the tree.
std::optional<float> log_prob;
// Score of the token; the Python binding exposes it as `log_prob` (log probability of the token).
std::optional<float> score;
// Per-vocabulary values for this step (the Python binding describes them as log
// probabilities over the vocab; shape and dtype are not visible here — TODO confirm).
std::optional<StorageView> logits;
// Whether this step is the last decoding step for this batch; defaults to false.
bool is_last = false;
};

Expand All @@ -41,6 +43,7 @@ namespace ctranslate2 {
const dim_t min_length,
const bool return_scores = false,
const bool return_attention = false,
const bool return_logits_vocab = true,
const bool return_prefix = true,
const size_t num_hypotheses = 1,
const bool include_eos_in_hypotheses = true,
Expand All @@ -67,6 +70,7 @@ namespace ctranslate2 {
const dim_t min_length,
const bool return_scores = false,
const bool return_attention = false,
const bool return_logits_vocab = true,
const bool return_prefix = true,
const size_t num_hypotheses = 1,
const bool include_eos_in_hypotheses = true,
Expand Down Expand Up @@ -118,6 +122,7 @@ namespace ctranslate2 {
const dim_t min_length,
const bool return_scores = false,
const bool return_attention = false,
const bool return_logits_vocab = true,
const bool return_prefix = true,
const size_t num_hypotheses = 1,
const bool include_eos_in_hypotheses = true,
Expand Down Expand Up @@ -149,6 +154,7 @@ namespace ctranslate2 {
bool include_eos_in_hypotheses = true;
bool return_scores = false;
bool return_attention = false;
bool return_logits_vocab = false;
bool return_alternatives = false;
bool return_prefix = true;
float min_alternative_expansion_prob = 0;
Expand Down
9 changes: 7 additions & 2 deletions include/ctranslate2/generation.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ namespace ctranslate2 {

// Include scores in the result.
bool return_scores = false;
// Include log probs of each token in the result
bool return_logits_vocab = false;

// Return alternatives at the first unconstrained decoding position. This is typically
// used with a prefix to provide alternatives at a specific location.
Expand All @@ -79,6 +81,7 @@ namespace ctranslate2 {
std::vector<std::vector<std::string>> sequences;
std::vector<std::vector<size_t>> sequences_ids;
std::vector<float> scores;
std::vector<std::vector<StorageView>> logits;

size_t num_sequences() const {
return sequences.size();
Expand All @@ -95,7 +98,8 @@ namespace ctranslate2 {
size_t token_id;
size_t hypothesis_id;
std::string token;
std::optional<float> log_prob;
std::optional<float> score;
std::optional<StorageView> logits;
bool is_last;

GenerationStepResult() = default;
Expand All @@ -105,7 +109,8 @@ namespace ctranslate2 {
, token_id(result.token_id)
, hypothesis_id(result.hypothesis_id)
, token(vocabulary.to_token(result.token_id))
, log_prob(result.log_prob)
, score(result.score)
, logits(result.logits)
, is_last(result.is_last)
{
}
Expand Down
4 changes: 4 additions & 0 deletions include/ctranslate2/models/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ namespace ctranslate2 {
// Include scores in the result.
bool return_scores = false;

// Include log probs of each token in the result
bool return_logits_vocab = false;

// Include the probability of the no speech token in the result.
bool return_no_speech_prob = false;

Expand All @@ -59,6 +62,7 @@ namespace ctranslate2 {
std::vector<std::vector<std::string>> sequences;
std::vector<std::vector<size_t>> sequences_ids;
std::vector<float> scores;
std::vector<std::vector<StorageView>> logits;
float no_speech_prob = 0;

size_t num_sequences() const {
Expand Down
7 changes: 6 additions & 1 deletion include/ctranslate2/translation.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ namespace ctranslate2 {
bool return_scores = false;
// Store attention vectors in the TranslationResult class.
bool return_attention = false;
// Store log probs matrix in the TranslationResult class.
bool return_logits_vocab = false;

// Return alternatives at the first unconstrained decoding position. This is typically
// used with a target prefix to provide alternatives at a specific location in the
Expand All @@ -87,6 +89,7 @@ namespace ctranslate2 {
std::vector<std::vector<std::string>> hypotheses;
std::vector<float> scores;
std::vector<std::vector<std::vector<float>>> attention;
std::vector<std::vector<StorageView>> logits;

TranslationResult(std::vector<std::vector<std::string>> hypotheses_)
: hypotheses(std::move(hypotheses_))
Expand All @@ -95,10 +98,12 @@ namespace ctranslate2 {

TranslationResult(std::vector<std::vector<std::string>> hypotheses_,
std::vector<float> scores_,
std::vector<std::vector<std::vector<float>>> attention_)
std::vector<std::vector<std::vector<float>>> attention_,
std::vector<std::vector<StorageView>> logits_)
: hypotheses(std::move(hypotheses_))
, scores(std::move(scores_))
, attention(std::move(attention_))
, logits(std::move(logits_))
{
}

Expand Down
10 changes: 8 additions & 2 deletions python/cpp/generation_result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ namespace ctranslate2 {
"Index of the hypothesis in the batch.")
.def_readonly("token", &GenerationStepResult::token,
"String value of the generated token.")
.def_readonly("log_prob", &GenerationStepResult::log_prob,
.def_readonly("log_prob", &GenerationStepResult::score,
"Log probability of the token (``None`` if :obj:`return_log_prob` was disabled).")
.def_readonly("logits", &GenerationStepResult::logits,
"Log probability on the vocab of all tokens.")
.def_readonly("is_last", &GenerationStepResult::is_last,
"Whether this step is the last decoding step for this batch.")

Expand All @@ -32,7 +34,8 @@ namespace ctranslate2 {
+ ", token_id=" + std::string(py::repr(py::cast(result.token_id)))
+ ", hypothesis_id=" + std::string(py::repr(py::cast(result.hypothesis_id)))
+ ", token=" + std::string(py::repr(py::cast(result.token)))
+ ", log_prob=" + std::string(py::repr(py::cast(result.log_prob)))
+ ", log_prob=" + std::string(py::repr(py::cast(result.score)))
+ ", logits=" + std::string(py::repr(py::cast(result.logits)))
+ ", is_last=" + std::string(py::repr(py::cast(result.is_last)))
+ ")";
})
Expand All @@ -46,11 +49,14 @@ namespace ctranslate2 {
"Generated sequences of token IDs.")
.def_readonly("scores", &GenerationResult::scores,
"Score of each sequence (empty if :obj:`return_scores` was disabled).")
.def_readonly("logits", &GenerationResult::logits,
"Logits of each sequence (empty if :obj:`return_logits_vocab` was disabled).")

.def("__repr__", [](const GenerationResult& result) {
return "GenerationResult(sequences=" + std::string(py::repr(py::cast(result.sequences)))
+ ", sequences_ids=" + std::string(py::repr(py::cast(result.sequences_ids)))
+ ", scores=" + std::string(py::repr(py::cast(result.scores)))
+ ", logits=" + std::string(py::repr(py::cast(result.logits)))
+ ")";
})
;
Expand Down
4 changes: 4 additions & 0 deletions python/cpp/generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ namespace ctranslate2 {
bool cache_static_prompt,
bool include_prompt_in_result,
bool return_scores,
bool return_logits_vocab,
bool return_alternatives,
float min_alternative_expansion_prob,
size_t sampling_topk,
Expand All @@ -58,6 +59,7 @@ namespace ctranslate2 {
options.num_hypotheses = num_hypotheses;
options.return_end_token = return_end_token;
options.return_scores = return_scores;
options.return_logits_vocab = return_logits_vocab;
options.return_alternatives = return_alternatives;
options.cache_static_prompt = cache_static_prompt;
options.include_prompt_in_result = include_prompt_in_result;
Expand Down Expand Up @@ -203,6 +205,7 @@ namespace ctranslate2 {
py::arg("cache_static_prompt")=true,
py::arg("include_prompt_in_result")=true,
py::arg("return_scores")=false,
py::arg("return_logits_vocab")=false,
py::arg("return_alternatives")=false,
py::arg("min_alternative_expansion_prob")=0,
py::arg("sampling_topk")=1,
Expand Down Expand Up @@ -260,6 +263,7 @@ namespace ctranslate2 {
reuse it for future generations using the same static prompt.
include_prompt_in_result: Include the :obj:`start_tokens` in the result.
return_scores: Include the scores in the output.
return_logits_vocab: Include log probs for each token in the output
return_alternatives: Return alternatives at the first unconstrained decoding position.
min_alternative_expansion_prob: Minimum initial probability to expand an alternative.
sampling_topk: Randomly sample predictions from the top K candidates.
Expand Down
6 changes: 6 additions & 0 deletions python/cpp/storage_view.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ namespace ctranslate2 {
return stream.str();
})

.def("__repr__", [](const StorageView& view) {
std::ostringstream stream;
stream << view;
return stream.str();
})

.def("to",
[](const StorageView& view, DataType dtype) {
ScopedDeviceSetter device_setter(view.device(), view.device_index());
Expand Down
3 changes: 3 additions & 0 deletions python/cpp/translation_result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ namespace ctranslate2 {
"Score of each translation hypothesis (empty if :obj:`return_scores` was disabled).")
.def_readonly("attention", &TranslationResult::attention,
"Attention matrix of each translation hypothesis (empty if :obj:`return_attention` was disabled).")
.def_readonly("logits", &TranslationResult::logits,
"Logits of each translation hypothesis (empty if :obj:`return_logits_vocab` was disabled).")

.def("__repr__", [](const TranslationResult& result) {
return "TranslationResult(hypotheses=" + std::string(py::repr(py::cast(result.hypotheses)))
+ ", scores=" + std::string(py::repr(py::cast(result.scores)))
+ ", attention=" + std::string(py::repr(py::cast(result.attention)))
+ ", logits=" + std::string(py::repr(py::cast(result.logits)))
+ ")";
})

Expand Down
4 changes: 4 additions & 0 deletions python/cpp/translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ namespace ctranslate2 {
size_t min_decoding_length,
bool use_vmap,
bool return_scores,
bool return_logits_vocab,
bool return_attention,
bool return_alternatives,
float min_alternative_expansion_prob,
Expand Down Expand Up @@ -172,6 +173,7 @@ namespace ctranslate2 {
options.use_vmap = use_vmap;
options.return_end_token = return_end_token;
options.return_scores = return_scores;
options.return_logits_vocab = return_logits_vocab;
options.return_attention = return_attention;
options.return_alternatives = return_alternatives;
options.min_alternative_expansion_prob = min_alternative_expansion_prob;
Expand Down Expand Up @@ -354,6 +356,7 @@ namespace ctranslate2 {
py::arg("min_decoding_length")=1,
py::arg("use_vmap")=false,
py::arg("return_scores")=false,
py::arg("return_logits_vocab")=false,
py::arg("return_attention")=false,
py::arg("return_alternatives")=false,
py::arg("min_alternative_expansion_prob")=0,
Expand Down Expand Up @@ -396,6 +399,7 @@ namespace ctranslate2 {
min_decoding_length: Minimum prediction length.
use_vmap: Use the vocabulary mapping file saved in this model
return_scores: Include the scores in the output.
return_logits_vocab: Include the log probs of each token in the output
return_attention: Include the attention vectors in the output.
return_alternatives: Return alternatives at the first unconstrained decoding position.
min_alternative_expansion_prob: Minimum initial probability to expand an alternative.
Expand Down
4 changes: 4 additions & 0 deletions python/cpp/whisper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ namespace ctranslate2 {
size_t no_repeat_ngram_size,
size_t max_length,
bool return_scores,
bool return_logits_vocab,
bool return_no_speech_prob,
size_t max_initial_timestamp_index,
bool suppress_blank,
Expand All @@ -59,6 +60,7 @@ namespace ctranslate2 {
options.max_length = max_length;
options.num_hypotheses = num_hypotheses;
options.return_scores = return_scores;
options.return_logits_vocab = return_logits_vocab;
options.return_no_speech_prob = return_no_speech_prob;
options.max_initial_timestamp_index = max_initial_timestamp_index;
options.suppress_blank = suppress_blank;
Expand Down Expand Up @@ -247,6 +249,7 @@ namespace ctranslate2 {
py::arg("no_repeat_ngram_size")=0,
py::arg("max_length")=448,
py::arg("return_scores")=false,
py::arg("return_logits_vocab")=false,
py::arg("return_no_speech_prob")=false,
py::arg("max_initial_timestamp_index")=50,
py::arg("suppress_blank")=true,
Expand Down Expand Up @@ -276,6 +279,7 @@ namespace ctranslate2 {
(set 0 to disable).
max_length: Maximum generation length.
return_scores: Include the scores in the output.
return_logits_vocab: Include the log probs in the output
return_no_speech_prob: Include the probability of the no speech token in the
result.
max_initial_timestamp_index: Maximum index of the first predicted timestamp.
Expand Down
9 changes: 6 additions & 3 deletions python/tests/test_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,12 @@ def test_batch_translation(max_batch_size):
assert output[0].scores[0] < 0
assert not output[0].attention

expected_repr = "TranslationResult(hypotheses=%s, scores=%s, attention=[])" % (
output[0].hypotheses,
output[0].scores,
expected_repr = (
"TranslationResult(hypotheses=%s, scores=%s, attention=[], logits=[])"
% (
output[0].hypotheses,
output[0].scores,
)
)
assert repr(output[0]) == expected_repr

Expand Down
Loading

0 comments on commit 6647945

Please sign in to comment.