Skip to content

Commit

Permalink
grammar : check the full vocab only if necessary (opt) (ggerganov#4306)
Browse files Browse the repository at this point in the history
* Check the full vocab for grammar only if necessary

* Fix missing logit restoration step (?)

Does this matter, actually?

* Fix whitespace / formatting

* Adjust comment

* Didn't mean to push test gbnf

* Split sampling into the helper function (?)

And also revert the changes made to the header

* common : fix final newline

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
  • Loading branch information
kalomaze and ggerganov authored Dec 23, 2023
1 parent e0a4002 commit b9ec82d
Showing 1 changed file with 45 additions and 3 deletions.
48 changes: 45 additions & 3 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,12 @@ static void sampler_queue(
}
}

llama_token llama_sampling_sample(
static llama_token llama_sampling_sample_impl(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
const int idx,
bool is_resampling) { // Add a parameter to indicate if we are resampling
const llama_sampling_params & params = ctx_sampling->params;

const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
Expand All @@ -173,8 +174,17 @@ llama_token llama_sampling_sample(

llama_token id = 0;

// Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx);

// Declare original_logits at the beginning of the function scope
std::vector<float> original_logits;

if (!is_resampling) {
// Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
}

// apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
Expand Down Expand Up @@ -210,7 +220,8 @@ llama_token llama_sampling_sample(
}
}

if (ctx_sampling->grammar != NULL) {
// If we are in the resampling phase, apply grammar checks before sampling logic
if (is_resampling && ctx_sampling->grammar != NULL) {
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
}

Expand Down Expand Up @@ -252,9 +263,40 @@ llama_token llama_sampling_sample(
}
}

if (ctx_sampling->grammar != NULL && !is_resampling) {
// Create an array with a single token data element for the sampled id
llama_token_data single_token_data = {id, logits[id], 0.0f};
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };

// Apply grammar constraints to the single token
llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);

// Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
bool is_valid = single_token_data_array.data[0].logit != -INFINITY;

// If the token is not valid according to the grammar, perform resampling
if (!is_valid) {
LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());

// Restore logits from the copy
std::copy(original_logits.begin(), original_logits.end(), logits);

return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
}
}

return id;
}

llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
// Call the implementation function with is_resampling set to false by default
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
}

void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
Expand Down

0 comments on commit b9ec82d

Please sign in to comment.