From 66479457d4675fc336afe84abc2b4f024e4ca67e Mon Sep 17 00:00:00 2001 From: Minh-Thuc <46375464+minhthuc2502@users.noreply.github.com> Date: Mon, 26 Aug 2024 09:39:33 +0200 Subject: [PATCH] Add log probs for all tokens (#1755) * add log probs * fix compilation * fix compilation * fix test * fix black * return logits * fix compilation * fix test * last clean --- include/ctranslate2/decoding.h | 8 +++- include/ctranslate2/generation.h | 9 ++++- include/ctranslate2/models/whisper.h | 4 ++ include/ctranslate2/translation.h | 7 +++- python/cpp/generation_result.cc | 10 ++++- python/cpp/generator.cc | 4 ++ python/cpp/storage_view.cc | 6 +++ python/cpp/translation_result.cc | 3 ++ python/cpp/translator.cc | 4 ++ python/cpp/whisper.cc | 4 ++ python/tests/test_translator.py | 9 +++-- src/decoding.cc | 57 +++++++++++++++++++++++++++- src/layers/attention.cc | 1 - src/layers/attention_layer.cc | 1 - src/models/language_model.cc | 2 + src/models/sequence_to_sequence.cc | 4 +- src/models/whisper.cc | 2 + 17 files changed, 122 insertions(+), 13 deletions(-) diff --git a/include/ctranslate2/decoding.h b/include/ctranslate2/decoding.h index 436280b46..5c1d316dc 100644 --- a/include/ctranslate2/decoding.h +++ b/include/ctranslate2/decoding.h @@ -15,6 +15,7 @@ namespace ctranslate2 { std::vector> hypotheses; std::vector scores; std::vector>> attention; + std::vector> logits_vocab; }; struct DecodingStepResult { @@ -22,7 +23,8 @@ namespace ctranslate2 { size_t batch_id; size_t token_id; size_t hypothesis_id; - std::optional log_prob; + std::optional score; + std::optional logits; bool is_last = false; }; @@ -41,6 +43,7 @@ namespace ctranslate2 { const dim_t min_length, const bool return_scores = false, const bool return_attention = false, + const bool return_logits_vocab = true, const bool return_prefix = true, const size_t num_hypotheses = 1, const bool include_eos_in_hypotheses = true, @@ -67,6 +70,7 @@ namespace ctranslate2 { const dim_t min_length, const bool return_scores = 
false, const bool return_attention = false, + const bool return_logits_vocab = true, const bool return_prefix = true, const size_t num_hypotheses = 1, const bool include_eos_in_hypotheses = true, @@ -118,6 +122,7 @@ namespace ctranslate2 { const dim_t min_length, const bool return_scores = false, const bool return_attention = false, + const bool return_logits_vocab = true, const bool return_prefix = true, const size_t num_hypotheses = 1, const bool include_eos_in_hypotheses = true, @@ -149,6 +154,7 @@ namespace ctranslate2 { bool include_eos_in_hypotheses = true; bool return_scores = false; bool return_attention = false; + bool return_logits_vocab = false; bool return_alternatives = false; bool return_prefix = true; float min_alternative_expansion_prob = 0; diff --git a/include/ctranslate2/generation.h b/include/ctranslate2/generation.h index f09aeef45..bd76146ff 100644 --- a/include/ctranslate2/generation.h +++ b/include/ctranslate2/generation.h @@ -53,6 +53,8 @@ namespace ctranslate2 { // Include scores in the result. bool return_scores = false; + // Include log probs of each token in the result + bool return_logits_vocab = false; // Return alternatives at the first unconstrained decoding position. This is typically // used with a prefix to provide alternatives at a specifc location. 
@@ -79,6 +81,7 @@ namespace ctranslate2 { std::vector> sequences; std::vector> sequences_ids; std::vector scores; + std::vector> logits; size_t num_sequences() const { return sequences.size(); @@ -95,7 +98,8 @@ namespace ctranslate2 { size_t token_id; size_t hypothesis_id; std::string token; - std::optional log_prob; + std::optional score; + std::optional logits; bool is_last; GenerationStepResult() = default; @@ -105,7 +109,8 @@ namespace ctranslate2 { , token_id(result.token_id) , hypothesis_id(result.hypothesis_id) , token(vocabulary.to_token(result.token_id)) - , log_prob(result.log_prob) + , score(result.score) + , logits(result.logits) , is_last(result.is_last) { } diff --git a/include/ctranslate2/models/whisper.h b/include/ctranslate2/models/whisper.h index 7ade2bd20..e9818cc4e 100644 --- a/include/ctranslate2/models/whisper.h +++ b/include/ctranslate2/models/whisper.h @@ -41,6 +41,9 @@ namespace ctranslate2 { // Include scores in the result. bool return_scores = false; + // Include log probs of each token in the result + bool return_logits_vocab = false; + // Include the probability of the no speech token in the result. bool return_no_speech_prob = false; @@ -59,6 +62,7 @@ namespace ctranslate2 { std::vector> sequences; std::vector> sequences_ids; std::vector scores; + std::vector> logits; float no_speech_prob = 0; size_t num_sequences() const { diff --git a/include/ctranslate2/translation.h b/include/ctranslate2/translation.h index 8d2ec943a..8e8222d3a 100644 --- a/include/ctranslate2/translation.h +++ b/include/ctranslate2/translation.h @@ -67,6 +67,8 @@ namespace ctranslate2 { bool return_scores = false; // Store attention vectors in the TranslationResult class. bool return_attention = false; + // Store log probs matrix in the TranslationResult class. + bool return_logits_vocab = false; // Return alternatives at the first unconstrained decoding position. 
This is typically // used with a target prefix to provide alternatives at a specifc location in the @@ -87,6 +89,7 @@ namespace ctranslate2 { std::vector> hypotheses; std::vector scores; std::vector>> attention; + std::vector> logits; TranslationResult(std::vector> hypotheses_) : hypotheses(std::move(hypotheses_)) @@ -95,10 +98,12 @@ namespace ctranslate2 { TranslationResult(std::vector> hypotheses_, std::vector scores_, - std::vector>> attention_) + std::vector>> attention_, + std::vector> logits_) : hypotheses(std::move(hypotheses_)) , scores(std::move(scores_)) , attention(std::move(attention_)) + , logits(std::move(logits_)) { } diff --git a/python/cpp/generation_result.cc b/python/cpp/generation_result.cc index 3d7685f6c..f2d500192 100644 --- a/python/cpp/generation_result.cc +++ b/python/cpp/generation_result.cc @@ -21,8 +21,10 @@ namespace ctranslate2 { "Index of the hypothesis in the batch.") .def_readonly("token", &GenerationStepResult::token, "String value of the generated token.") - .def_readonly("log_prob", &GenerationStepResult::log_prob, + .def_readonly("log_prob", &GenerationStepResult::score, "Log probability of the token (``None`` if :obj:`return_log_prob` was disabled).") + .def_readonly("logits", &GenerationStepResult::logits, + "Log probability on the vocab of all tokens.") .def_readonly("is_last", &GenerationStepResult::is_last, "Whether this step is the last decoding step for this batch.") @@ -32,7 +34,8 @@ namespace ctranslate2 { + ", token_id=" + std::string(py::repr(py::cast(result.token_id))) + ", hypothesis_id=" + std::string(py::repr(py::cast(result.hypothesis_id))) + ", token=" + std::string(py::repr(py::cast(result.token))) - + ", log_prob=" + std::string(py::repr(py::cast(result.log_prob))) + + ", log_prob=" + std::string(py::repr(py::cast(result.score))) + + ", logits=" + std::string(py::repr(py::cast(result.logits))) + ", is_last=" + std::string(py::repr(py::cast(result.is_last))) + ")"; }) @@ -46,11 +49,14 @@ namespace ctranslate2 
{ "Generated sequences of token IDs.") .def_readonly("scores", &GenerationResult::scores, "Score of each sequence (empty if :obj:`return_scores` was disabled).") + .def_readonly("logits", &GenerationResult::logits, + "Log probabilities over the vocabulary of each generated token (empty if :obj:`return_logits_vocab` was disabled).") .def("__repr__", [](const GenerationResult& result) { return "GenerationResult(sequences=" + std::string(py::repr(py::cast(result.sequences))) + ", sequences_ids=" + std::string(py::repr(py::cast(result.sequences_ids))) + ", scores=" + std::string(py::repr(py::cast(result.scores))) + + ", logits=" + std::string(py::repr(py::cast(result.logits))) + ")"; }) ; diff --git a/python/cpp/generator.cc b/python/cpp/generator.cc index c6b19d1cf..c09befe2b 100644 --- a/python/cpp/generator.cc +++ b/python/cpp/generator.cc @@ -33,6 +33,7 @@ namespace ctranslate2 { bool cache_static_prompt, bool include_prompt_in_result, bool return_scores, + bool return_logits_vocab, bool return_alternatives, float min_alternative_expansion_prob, size_t sampling_topk, @@ -58,6 +59,7 @@ namespace ctranslate2 { options.num_hypotheses = num_hypotheses; options.return_end_token = return_end_token; options.return_scores = return_scores; + options.return_logits_vocab = return_logits_vocab; options.return_alternatives = return_alternatives; options.cache_static_prompt = cache_static_prompt; options.include_prompt_in_result = include_prompt_in_result; @@ -203,6 +205,7 @@ namespace ctranslate2 { py::arg("cache_static_prompt")=true, py::arg("include_prompt_in_result")=true, py::arg("return_scores")=false, + py::arg("return_logits_vocab")=false, py::arg("return_alternatives")=false, py::arg("min_alternative_expansion_prob")=0, py::arg("sampling_topk")=1, @@ -260,6 +263,7 @@ namespace ctranslate2 { reuse it for future generations using the same static prompt. include_prompt_in_result: Include the :obj:`start_tokens` in the result. return_scores: Include the scores in the output. 
+ return_logits_vocab: Include log probs for each token in the output return_alternatives: Return alternatives at the first unconstrained decoding position. min_alternative_expansion_prob: Minimum initial probability to expand an alternative. sampling_topk: Randomly sample predictions from the top K candidates. diff --git a/python/cpp/storage_view.cc b/python/cpp/storage_view.cc index 7c1f14ec2..56c95f9ca 100644 --- a/python/cpp/storage_view.cc +++ b/python/cpp/storage_view.cc @@ -192,6 +192,12 @@ namespace ctranslate2 { return stream.str(); }) + .def("__repr__", [](const StorageView& view) { + std::ostringstream stream; + stream << view; + return stream.str(); + }) + .def("to", [](const StorageView& view, DataType dtype) { ScopedDeviceSetter device_setter(view.device(), view.device_index()); diff --git a/python/cpp/translation_result.cc b/python/cpp/translation_result.cc index 3b8a0790b..fa7d70f4d 100644 --- a/python/cpp/translation_result.cc +++ b/python/cpp/translation_result.cc @@ -16,11 +16,14 @@ namespace ctranslate2 { "Score of each translation hypothesis (empty if :obj:`return_scores` was disabled).") .def_readonly("attention", &TranslationResult::attention, "Attention matrix of each translation hypothesis (empty if :obj:`return_attention` was disabled).") + .def_readonly("logits", &TranslationResult::logits, + "Log probabilities over the vocabulary of each translation hypothesis (empty if :obj:`return_logits_vocab` was disabled).") .def("__repr__", [](const TranslationResult& result) { return "TranslationResult(hypotheses=" + std::string(py::repr(py::cast(result.hypotheses))) + ", scores=" + std::string(py::repr(py::cast(result.scores))) + ", attention=" + std::string(py::repr(py::cast(result.attention))) + + ", logits=" + std::string(py::repr(py::cast(result.logits))) + ")"; }) diff --git a/python/cpp/translator.cc b/python/cpp/translator.cc index 52902b986..319b524cc 100644 --- a/python/cpp/translator.cc +++ b/python/cpp/translator.cc @@ -141,6 +141,7 @@ namespace ctranslate2 { size_t 
min_decoding_length, bool use_vmap, bool return_scores, + bool return_logits_vocab, bool return_attention, bool return_alternatives, float min_alternative_expansion_prob, @@ -172,6 +173,7 @@ namespace ctranslate2 { options.use_vmap = use_vmap; options.return_end_token = return_end_token; options.return_scores = return_scores; + options.return_logits_vocab = return_logits_vocab; options.return_attention = return_attention; options.return_alternatives = return_alternatives; options.min_alternative_expansion_prob = min_alternative_expansion_prob; @@ -354,6 +356,7 @@ namespace ctranslate2 { py::arg("min_decoding_length")=1, py::arg("use_vmap")=false, py::arg("return_scores")=false, + py::arg("return_logits_vocab")=false, py::arg("return_attention")=false, py::arg("return_alternatives")=false, py::arg("min_alternative_expansion_prob")=0, @@ -396,6 +399,7 @@ namespace ctranslate2 { min_decoding_length: Minimum prediction length. use_vmap: Use the vocabulary mapping file saved in this model return_scores: Include the scores in the output. + return_logits_vocab: Include the log probs of each token in the output return_attention: Include the attention vectors in the output. return_alternatives: Return alternatives at the first unconstrained decoding position. min_alternative_expansion_prob: Minimum initial probability to expand an alternative. 
diff --git a/python/cpp/whisper.cc b/python/cpp/whisper.cc index c9463b64f..d0156c8c1 100644 --- a/python/cpp/whisper.cc +++ b/python/cpp/whisper.cc @@ -40,6 +40,7 @@ namespace ctranslate2 { size_t no_repeat_ngram_size, size_t max_length, bool return_scores, + bool return_logits_vocab, bool return_no_speech_prob, size_t max_initial_timestamp_index, bool suppress_blank, @@ -59,6 +60,7 @@ namespace ctranslate2 { options.max_length = max_length; options.num_hypotheses = num_hypotheses; options.return_scores = return_scores; + options.return_logits_vocab = return_logits_vocab; options.return_no_speech_prob = return_no_speech_prob; options.max_initial_timestamp_index = max_initial_timestamp_index; options.suppress_blank = suppress_blank; @@ -247,6 +249,7 @@ namespace ctranslate2 { py::arg("no_repeat_ngram_size")=0, py::arg("max_length")=448, py::arg("return_scores")=false, + py::arg("return_logits_vocab")=false, py::arg("return_no_speech_prob")=false, py::arg("max_initial_timestamp_index")=50, py::arg("suppress_blank")=true, @@ -276,6 +279,7 @@ namespace ctranslate2 { (set 0 to disable). max_length: Maximum generation length. return_scores: Include the scores in the output. + return_logits_vocab: Include the log probs in the output return_no_speech_prob: Include the probability of the no speech token in the result. max_initial_timestamp_index: Maximum index of the first predicted timestamp. 
diff --git a/python/tests/test_translator.py b/python/tests/test_translator.py index c64189226..f76b78c31 100644 --- a/python/tests/test_translator.py +++ b/python/tests/test_translator.py @@ -111,9 +111,12 @@ def test_batch_translation(max_batch_size): assert output[0].scores[0] < 0 assert not output[0].attention - expected_repr = "TranslationResult(hypotheses=%s, scores=%s, attention=[])" % ( - output[0].hypotheses, - output[0].scores, + expected_repr = ( + "TranslationResult(hypotheses=%s, scores=%s, attention=[], logits=[])" + % ( + output[0].hypotheses, + output[0].scores, + ) ) assert repr(output[0]) == expected_repr diff --git a/src/decoding.cc b/src/decoding.cc index 418389e2c..55a9d7844 100644 --- a/src/decoding.cc +++ b/src/decoding.cc @@ -157,6 +157,22 @@ namespace ctranslate2 { return attention; } + static std::vector build_logits(const StorageView& history, + const dim_t batch) { + if (!history) + return {}; + std::vector logits; + logits.reserve(batch); + for (dim_t t = 0; t < batch; ++t) { + ops::Slide slide(0, t, 1); + StorageView tmp(history.dtype(), history.device()); + slide(history, tmp); + logits.emplace_back(std::move(tmp.squeeze(0))); + } + + return logits; + } + static float compute_coverage_penalty(const std::vector>& attention, const float beta) { float penalty = 0; @@ -409,6 +425,7 @@ namespace ctranslate2 { const dim_t min_length, const bool return_scores, const bool return_attention, + const bool return_logits_vocab, const bool return_prefix, const size_t num_hypotheses, const bool include_eos_in_hypotheses, @@ -501,6 +518,9 @@ namespace ctranslate2 { } disable_tokens.apply(); + std::vector logits_vec; + if (return_logits_vocab) + logits_vec = build_logits(logits, cur_batch_size); StorageView log_probs(dtype, device); if (bias_towards_prefix) { @@ -582,6 +602,11 @@ namespace ctranslate2 { auto& result = results[batch_id]; dim_t secondary_candidates_offset = _beam_size; + if (return_logits_vocab) { + 
results[batch_id].logits_vocab.resize(1); + results[batch_id].logits_vocab[0].emplace_back(std::move(logits_vec[i])); + } + for (dim_t k = 0; k < _beam_size; ++k) { const size_t last_id = topk_ids.at({i, k}); dim_t next_beam_id = k; @@ -705,6 +730,7 @@ namespace ctranslate2 { const dim_t min_length, const bool return_scores, const bool return_attention, + const bool return_logits_vocab, const bool return_prefix, const size_t num_hypotheses, const bool include_eos_in_hypotheses, @@ -750,6 +776,7 @@ namespace ctranslate2 { min_length, /*return_scores=*/true, return_attention, + return_logits_vocab, return_prefix, /*num_hypotheses=*/1, include_eos_in_hypotheses, @@ -766,6 +793,8 @@ namespace ctranslate2 { final_result.scores.emplace_back(result.scores[0]); if (return_attention) final_result.attention.emplace_back(std::move(result.attention[0])); + if (return_logits_vocab) + final_result.logits_vocab.emplace_back(std::move(result.logits_vocab[0])); } for (auto& result : final_results) @@ -826,6 +855,12 @@ namespace ctranslate2 { disable_tokens.apply(); + std::vector logits_vec; + StorageView logits_orig(dtype, device); + if (return_logits_vocab) { + logits_vec = build_logits(logits, logits.dim(0)); + logits_orig.copy_from(logits); + } // Compute log probs only if required. StorageView log_probs(dtype, device); if (return_scores) @@ -857,6 +892,11 @@ namespace ctranslate2 { const dim_t prefix_length = prefix_ids ? 
prefix_ids->at(batch_id).size() : 0; const float score = best_probs.scalar_at({i, 0}); + if (return_logits_vocab) { + results[batch_id].logits_vocab.resize(1); + results[batch_id].logits_vocab[0].emplace_back(std::move(logits_vec[i])); + } + if ((!is_eos(word_id, end_ids) || include_eos_in_hypotheses) && (return_prefix || step >= prefix_length)) { results[batch_id].hypotheses[0].push_back(word_id); @@ -880,7 +920,9 @@ namespace ctranslate2 { step_result.hypothesis_id = 0; step_result.is_last = is_finished; if (return_scores) - step_result.log_prob = score; + step_result.score = score; + if (return_logits_vocab) + step_result.logits = std::move(logits_orig); if (_callback(std::move(step_result))) { is_finished = true; } @@ -1078,6 +1120,8 @@ namespace ctranslate2 { result.scores.resize(options.num_hypotheses, 0); if (options.return_attention) result.attention.resize(options.num_hypotheses); + if (options.return_logits_vocab) + result.logits_vocab.resize(options.num_hypotheses); if (start_tokens.empty()) throw std::invalid_argument("One input has no decoder start token"); @@ -1140,6 +1184,7 @@ namespace ctranslate2 { /*min_length=*/1, /*return_scores=*/true, options.return_attention, + options.return_logits_vocab, options.return_prefix, options.num_hypotheses, options.include_eos_in_hypotheses, @@ -1158,6 +1203,8 @@ namespace ctranslate2 { result.attention[i].emplace_back(std::move(expansion_result.attention[i].back())); if (options.return_scores) result.scores[i] = expansion_result.scores[i]; + if (options.return_logits_vocab) + result.logits_vocab[i].emplace_back(std::move(expansion_result.logits_vocab[i].back())); // The next input is the words we just expanded. 
start_ids.push_back(result.hypotheses[i].back()); @@ -1201,6 +1248,7 @@ namespace ctranslate2 { std::max(min_length - start_step, dim_t(0)), options.return_scores, options.return_attention, + options.return_logits_vocab, options.return_prefix, /*num_hypotheses=*/1, options.include_eos_in_hypotheses, @@ -1214,6 +1262,12 @@ namespace ctranslate2 { result.scores[i] += suffix.scores[0]; } + if (options.return_logits_vocab) { + result.logits_vocab[i].insert(result.logits_vocab[i].end(), + std::make_move_iterator(suffix.logits_vocab[0].begin()), + std::make_move_iterator(suffix.logits_vocab[0].end())); + } + if (options.return_attention) result.attention[i].insert(result.attention[i].end(), std::make_move_iterator(suffix.attention[0].begin()), @@ -1293,6 +1347,7 @@ namespace ctranslate2 { options.min_length, options.return_scores, options.return_attention, + options.return_logits_vocab, options.return_prefix, options.num_hypotheses, options.include_eos_in_hypotheses, diff --git a/src/layers/attention.cc b/src/layers/attention.cc index 18e2710f7..24ffffdc8 100644 --- a/src/layers/attention.cc +++ b/src/layers/attention.cc @@ -462,7 +462,6 @@ namespace ctranslate2 { } else { combine_heads(context, _num_heads, queries_padder, beam_size); } - _linear.back()(context, output); if (_tensor_parallel) { diff --git a/src/layers/attention_layer.cc b/src/layers/attention_layer.cc index d6a823024..c9ae67409 100644 --- a/src/layers/attention_layer.cc +++ b/src/layers/attention_layer.cc @@ -7,7 +7,6 @@ #include "dispatch.h" #include "cpu/parallel.h" -#include namespace ctranslate2 { namespace layers { diff --git a/src/models/language_model.cc b/src/models/language_model.cc index 01ae7c8a4..5a23fa35a 100644 --- a/src/models/language_model.cc +++ b/src/models/language_model.cc @@ -165,6 +165,7 @@ namespace ctranslate2 { decoding_options.sampling_temperature = options.sampling_temperature; decoding_options.num_hypotheses = options.num_hypotheses; decoding_options.return_scores = 
options.return_scores; + decoding_options.return_logits_vocab = options.return_logits_vocab; decoding_options.return_alternatives = options.return_alternatives; decoding_options.min_alternative_expansion_prob = options.min_alternative_expansion_prob; decoding_options.disable_sequences = vocabulary.to_ids(options.suppress_sequences, @@ -268,6 +269,7 @@ namespace ctranslate2 { final_result.sequences = vocabulary.to_tokens(result.hypotheses); final_result.sequences_ids = std::move(result.hypotheses); final_result.scores = std::move(result.scores); + final_result.logits = std::move(result.logits_vocab); final_results.emplace_back(std::move(final_result)); } diff --git a/src/models/sequence_to_sequence.cc b/src/models/sequence_to_sequence.cc index ed4bb214b..67e0facc4 100644 --- a/src/models/sequence_to_sequence.cc +++ b/src/models/sequence_to_sequence.cc @@ -348,6 +348,7 @@ namespace ctranslate2 { decoding_options.sampling_temperature = options.sampling_temperature; decoding_options.num_hypotheses = options.num_hypotheses; decoding_options.return_scores = options.return_scores; + decoding_options.return_logits_vocab = options.return_logits_vocab; decoding_options.return_attention = options.return_attention || options.replace_unknowns; decoding_options.return_alternatives = options.return_alternatives; decoding_options.min_alternative_expansion_prob = options.min_alternative_expansion_prob; @@ -423,7 +424,8 @@ namespace ctranslate2 { final_results.emplace_back(std::move(hypotheses), std::move(result.scores), - std::move(result.attention)); + std::move(result.attention), + std::move(result.logits_vocab)); } return final_results; diff --git a/src/models/whisper.cc b/src/models/whisper.cc index 349279240..7cdf2dc5b 100644 --- a/src/models/whisper.cc +++ b/src/models/whisper.cc @@ -302,6 +302,7 @@ namespace ctranslate2 { decoding_options.sampling_temperature = options.sampling_temperature; decoding_options.num_hypotheses = options.num_hypotheses; 
decoding_options.return_scores = options.return_scores; + decoding_options.return_logits_vocab = options.return_logits_vocab; decoding_options.include_eos_in_hypotheses = false; for (const auto& id : options.suppress_tokens) { @@ -356,6 +357,7 @@ namespace ctranslate2 { final_result.sequences = vocabulary.to_tokens(result.hypotheses); final_result.sequences_ids = std::move(result.hypotheses); final_result.scores = std::move(result.scores); + final_result.logits = std::move(result.logits_vocab); if (options.return_no_speech_prob) final_result.no_speech_prob = no_speech_probs[i];