Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add log probs for all tokens #1755

Merged
merged 9 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion include/ctranslate2/decoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@ namespace ctranslate2 {
std::vector<std::vector<size_t>> hypotheses;
std::vector<float> scores;
std::vector<std::vector<std::vector<float>>> attention;
std::vector<std::vector<StorageView>> logits_vocab;
};

// Per-step record describing one token produced during decoding.
// NOTE(review): both log_prob and score appear below. This text looks like a rendered
// diff in which log_prob is the pre-change field and score is its replacement; confirm
// against the real header before relying on either name.
struct DecodingStepResult {
// Index of the decoding step (presumably 0-based; TODO confirm).
size_t step;
// Presumably the index of the example within the batch; verify against the producer.
size_t batch_id;
// ID of the generated token; GenerationStepResult maps it to a string with
// vocabulary.to_token(token_id).
size_t token_id;
// Presumably the index of the hypothesis this token belongs to; verify against the producer.
size_t hypothesis_id;
// Optional log probability of the token (see the NOTE above about log_prob vs. score).
std::optional<float> log_prob;
// Optional score of the token; the Python binding exposes GenerationStepResult::score
// under the attribute name "log_prob".
std::optional<float> score;
// Optional per-vocabulary values for this step. The Python binding describes them as
// log probabilities over the vocabulary; presumably unset unless return_logits_vocab
// is enabled (TODO confirm).
std::optional<StorageView> logits;
// Defaults to false; the Python binding describes it as marking the last decoding
// step for this batch.
bool is_last = false;
};

Expand All @@ -41,6 +43,7 @@ namespace ctranslate2 {
const dim_t min_length,
const bool return_scores = false,
const bool return_attention = false,
const bool return_logits_vocab = true,
const bool return_prefix = true,
const size_t num_hypotheses = 1,
const bool include_eos_in_hypotheses = true,
Expand All @@ -67,6 +70,7 @@ namespace ctranslate2 {
const dim_t min_length,
const bool return_scores = false,
const bool return_attention = false,
const bool return_logits_vocab = true,
const bool return_prefix = true,
const size_t num_hypotheses = 1,
const bool include_eos_in_hypotheses = true,
Expand Down Expand Up @@ -118,6 +122,7 @@ namespace ctranslate2 {
const dim_t min_length,
const bool return_scores = false,
const bool return_attention = false,
const bool return_logits_vocab = true,
const bool return_prefix = true,
const size_t num_hypotheses = 1,
const bool include_eos_in_hypotheses = true,
Expand Down Expand Up @@ -149,6 +154,7 @@ namespace ctranslate2 {
bool include_eos_in_hypotheses = true;
bool return_scores = false;
bool return_attention = false;
bool return_logits_vocab = false;
bool return_alternatives = false;
bool return_prefix = true;
float min_alternative_expansion_prob = 0;
Expand Down
9 changes: 7 additions & 2 deletions include/ctranslate2/generation.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ namespace ctranslate2 {

// Include scores in the result.
bool return_scores = false;
// Include the log probs over the whole vocabulary for each token in the result.
bool return_logits_vocab = false;

// Return alternatives at the first unconstrained decoding position. This is typically
// used with a prefix to provide alternatives at a specific location.
Expand All @@ -79,6 +81,7 @@ namespace ctranslate2 {
std::vector<std::vector<std::string>> sequences;
std::vector<std::vector<size_t>> sequences_ids;
std::vector<float> scores;
std::vector<std::vector<StorageView>> logits;

size_t num_sequences() const {
return sequences.size();
Expand All @@ -95,7 +98,8 @@ namespace ctranslate2 {
size_t token_id;
size_t hypothesis_id;
std::string token;
std::optional<float> log_prob;
std::optional<float> score;
std::optional<StorageView> logits;
bool is_last;

GenerationStepResult() = default;
Expand All @@ -105,7 +109,8 @@ namespace ctranslate2 {
, token_id(result.token_id)
, hypothesis_id(result.hypothesis_id)
, token(vocabulary.to_token(result.token_id))
, log_prob(result.log_prob)
, score(result.score)
, logits(result.logits)
, is_last(result.is_last)
{
}
Expand Down
4 changes: 4 additions & 0 deletions include/ctranslate2/models/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ namespace ctranslate2 {
// Include scores in the result.
bool return_scores = false;

// Include the log probs over the whole vocabulary for each token in the result.
bool return_logits_vocab = false;

// Include the probability of the no speech token in the result.
bool return_no_speech_prob = false;

Expand All @@ -59,6 +62,7 @@ namespace ctranslate2 {
std::vector<std::vector<std::string>> sequences;
std::vector<std::vector<size_t>> sequences_ids;
std::vector<float> scores;
std::vector<std::vector<StorageView>> logits;
float no_speech_prob = 0;

size_t num_sequences() const {
Expand Down
7 changes: 6 additions & 1 deletion include/ctranslate2/translation.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ namespace ctranslate2 {
bool return_scores = false;
// Store attention vectors in the TranslationResult class.
bool return_attention = false;
// Store log probs matrix in the TranslationResult class.
bool return_logits_vocab = false;

// Return alternatives at the first unconstrained decoding position. This is typically
// used with a target prefix to provide alternatives at a specific location in the
Expand All @@ -87,6 +89,7 @@ namespace ctranslate2 {
std::vector<std::vector<std::string>> hypotheses;
std::vector<float> scores;
std::vector<std::vector<std::vector<float>>> attention;
std::vector<std::vector<StorageView>> logits;

TranslationResult(std::vector<std::vector<std::string>> hypotheses_)
: hypotheses(std::move(hypotheses_))
Expand All @@ -95,10 +98,12 @@ namespace ctranslate2 {

TranslationResult(std::vector<std::vector<std::string>> hypotheses_,
std::vector<float> scores_,
std::vector<std::vector<std::vector<float>>> attention_)
std::vector<std::vector<std::vector<float>>> attention_,
std::vector<std::vector<StorageView>> logits_)
: hypotheses(std::move(hypotheses_))
, scores(std::move(scores_))
, attention(std::move(attention_))
, logits(std::move(logits_))
{
}

Expand Down
10 changes: 8 additions & 2 deletions python/cpp/generation_result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ namespace ctranslate2 {
"Index of the hypothesis in the batch.")
.def_readonly("token", &GenerationStepResult::token,
"String value of the generated token.")
.def_readonly("log_prob", &GenerationStepResult::log_prob,
.def_readonly("log_prob", &GenerationStepResult::score,
"Log probability of the token (``None`` if :obj:`return_log_prob` was disabled).")
.def_readonly("logits", &GenerationStepResult::logits,
"Log probability on the vocab of all tokens.")
.def_readonly("is_last", &GenerationStepResult::is_last,
"Whether this step is the last decoding step for this batch.")

Expand All @@ -32,7 +34,8 @@ namespace ctranslate2 {
+ ", token_id=" + std::string(py::repr(py::cast(result.token_id)))
+ ", hypothesis_id=" + std::string(py::repr(py::cast(result.hypothesis_id)))
+ ", token=" + std::string(py::repr(py::cast(result.token)))
+ ", log_prob=" + std::string(py::repr(py::cast(result.log_prob)))
+ ", log_prob=" + std::string(py::repr(py::cast(result.score)))
+ ", logits=" + std::string(py::repr(py::cast(result.logits)))
+ ", is_last=" + std::string(py::repr(py::cast(result.is_last)))
+ ")";
})
Expand All @@ -46,11 +49,14 @@ namespace ctranslate2 {
"Generated sequences of token IDs.")
.def_readonly("scores", &GenerationResult::scores,
"Score of each sequence (empty if :obj:`return_scores` was disabled).")
.def_readonly("logits", &GenerationResult::logits,
"Score of each sequence (empty if :obj:`return_logits_vocab` was disabled).")

.def("__repr__", [](const GenerationResult& result) {
return "GenerationResult(sequences=" + std::string(py::repr(py::cast(result.sequences)))
+ ", sequences_ids=" + std::string(py::repr(py::cast(result.sequences_ids)))
+ ", scores=" + std::string(py::repr(py::cast(result.scores)))
+ ", logits=" + std::string(py::repr(py::cast(result.logits)))
+ ")";
})
;
Expand Down
4 changes: 4 additions & 0 deletions python/cpp/generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ namespace ctranslate2 {
bool cache_static_prompt,
bool include_prompt_in_result,
bool return_scores,
bool return_logits_vocab,
bool return_alternatives,
float min_alternative_expansion_prob,
size_t sampling_topk,
Expand All @@ -58,6 +59,7 @@ namespace ctranslate2 {
options.num_hypotheses = num_hypotheses;
options.return_end_token = return_end_token;
options.return_scores = return_scores;
options.return_logits_vocab = return_logits_vocab;
options.return_alternatives = return_alternatives;
options.cache_static_prompt = cache_static_prompt;
options.include_prompt_in_result = include_prompt_in_result;
Expand Down Expand Up @@ -203,6 +205,7 @@ namespace ctranslate2 {
py::arg("cache_static_prompt")=true,
py::arg("include_prompt_in_result")=true,
py::arg("return_scores")=false,
py::arg("return_logits_vocab")=false,
py::arg("return_alternatives")=false,
py::arg("min_alternative_expansion_prob")=0,
py::arg("sampling_topk")=1,
Expand Down Expand Up @@ -260,6 +263,7 @@ namespace ctranslate2 {
reuse it for future generations using the same static prompt.
include_prompt_in_result: Include the :obj:`start_tokens` in the result.
return_scores: Include the scores in the output.
return_logits_vocab: Include log probs for each token in the output
return_alternatives: Return alternatives at the first unconstrained decoding position.
min_alternative_expansion_prob: Minimum initial probability to expand an alternative.
sampling_topk: Randomly sample predictions from the top K candidates.
Expand Down
6 changes: 6 additions & 0 deletions python/cpp/storage_view.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ namespace ctranslate2 {
return stream.str();
})

.def("__repr__", [](const StorageView& view) {
std::ostringstream stream;
stream << view;
return stream.str();
})

.def("to",
[](const StorageView& view, DataType dtype) {
ScopedDeviceSetter device_setter(view.device(), view.device_index());
Expand Down
3 changes: 3 additions & 0 deletions python/cpp/translation_result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ namespace ctranslate2 {
"Score of each translation hypothesis (empty if :obj:`return_scores` was disabled).")
.def_readonly("attention", &TranslationResult::attention,
"Attention matrix of each translation hypothesis (empty if :obj:`return_attention` was disabled).")
.def_readonly("logits", &TranslationResult::logits,
"Score of each translation hypothesis (empty if :obj:`return_logits_vocab` was disabled).")

.def("__repr__", [](const TranslationResult& result) {
return "TranslationResult(hypotheses=" + std::string(py::repr(py::cast(result.hypotheses)))
+ ", scores=" + std::string(py::repr(py::cast(result.scores)))
+ ", attention=" + std::string(py::repr(py::cast(result.attention)))
+ ", logits=" + std::string(py::repr(py::cast(result.logits)))
+ ")";
})

Expand Down
4 changes: 4 additions & 0 deletions python/cpp/translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ namespace ctranslate2 {
size_t min_decoding_length,
bool use_vmap,
bool return_scores,
bool return_logits_vocab,
bool return_attention,
bool return_alternatives,
float min_alternative_expansion_prob,
Expand Down Expand Up @@ -172,6 +173,7 @@ namespace ctranslate2 {
options.use_vmap = use_vmap;
options.return_end_token = return_end_token;
options.return_scores = return_scores;
options.return_logits_vocab = return_logits_vocab;
options.return_attention = return_attention;
options.return_alternatives = return_alternatives;
options.min_alternative_expansion_prob = min_alternative_expansion_prob;
Expand Down Expand Up @@ -354,6 +356,7 @@ namespace ctranslate2 {
py::arg("min_decoding_length")=1,
py::arg("use_vmap")=false,
py::arg("return_scores")=false,
py::arg("return_logits_vocab")=false,
py::arg("return_attention")=false,
py::arg("return_alternatives")=false,
py::arg("min_alternative_expansion_prob")=0,
Expand Down Expand Up @@ -396,6 +399,7 @@ namespace ctranslate2 {
min_decoding_length: Minimum prediction length.
use_vmap: Use the vocabulary mapping file saved in this model
return_scores: Include the scores in the output.
return_logits_vocab: Include the log probs of each token in the output
return_attention: Include the attention vectors in the output.
return_alternatives: Return alternatives at the first unconstrained decoding position.
min_alternative_expansion_prob: Minimum initial probability to expand an alternative.
Expand Down
4 changes: 4 additions & 0 deletions python/cpp/whisper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ namespace ctranslate2 {
size_t no_repeat_ngram_size,
size_t max_length,
bool return_scores,
bool return_logits_vocab,
bool return_no_speech_prob,
size_t max_initial_timestamp_index,
bool suppress_blank,
Expand All @@ -59,6 +60,7 @@ namespace ctranslate2 {
options.max_length = max_length;
options.num_hypotheses = num_hypotheses;
options.return_scores = return_scores;
options.return_logits_vocab = return_logits_vocab;
options.return_no_speech_prob = return_no_speech_prob;
options.max_initial_timestamp_index = max_initial_timestamp_index;
options.suppress_blank = suppress_blank;
Expand Down Expand Up @@ -247,6 +249,7 @@ namespace ctranslate2 {
py::arg("no_repeat_ngram_size")=0,
py::arg("max_length")=448,
py::arg("return_scores")=false,
py::arg("return_logits_vocab")=false,
py::arg("return_no_speech_prob")=false,
py::arg("max_initial_timestamp_index")=50,
py::arg("suppress_blank")=true,
Expand Down Expand Up @@ -276,6 +279,7 @@ namespace ctranslate2 {
(set 0 to disable).
max_length: Maximum generation length.
return_scores: Include the scores in the output.
return_logits_vocab: Include the log probs in the output
return_no_speech_prob: Include the probability of the no speech token in the
result.
max_initial_timestamp_index: Maximum index of the first predicted timestamp.
Expand Down
9 changes: 6 additions & 3 deletions python/tests/test_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,12 @@ def test_batch_translation(max_batch_size):
assert output[0].scores[0] < 0
assert not output[0].attention

expected_repr = "TranslationResult(hypotheses=%s, scores=%s, attention=[])" % (
output[0].hypotheses,
output[0].scores,
expected_repr = (
"TranslationResult(hypotheses=%s, scores=%s, attention=[], logits=[])"
% (
output[0].hypotheses,
output[0].scores,
)
)
assert repr(output[0]) == expected_repr

Expand Down
Loading
Loading