Check the full vocab for grammar only if necessary #4306

Merged 8 commits on Dec 23, 2023
common/sampling.cpp: 48 changes (45 additions, 3 deletions)
@@ -149,11 +149,12 @@ static void sampler_queue(
}
}

llama_token llama_sampling_sample(
static llama_token llama_sampling_sample_impl(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
const int idx,
bool is_resampling) { // whether this call is a resampling pass after a grammar rejection
const llama_sampling_params & params = ctx_sampling->params;

const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -173,8 +174,17 @@ llama_token llama_sampling_sample(

llama_token id = 0;

// Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx);

// Backup of the original logits, used to restore them if we need to resample
std::vector<float> original_logits;

if (!is_resampling) {
// Only make a copy of the original logits when not resampling; the copy is needed to restore the logits if the sampled token is later rejected by the grammar
original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
}

// apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
@@ -210,7 +220,8 @@
}
}

if (ctx_sampling->grammar != NULL) {
// Apply the grammar to the full candidate set only when resampling; on the first pass, only the sampled token is checked below
if (is_resampling && ctx_sampling->grammar != NULL) {
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
}

@@ -252,9 +263,40 @@
}
}

if (ctx_sampling->grammar != NULL && !is_resampling) {
// Create an array with a single token data element for the sampled id
llama_token_data single_token_data = {id, logits[id], 0.0f};
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };

// Apply grammar constraints to the single token
llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);

// Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
Collaborator commented on lines +268 to +275:
Yeah we could probably expose something more straightforward to check this if we want. It would probably be like llama_grammar_accept_token but returning a bool instead of throwing.
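
For illustration, such a helper might reuse the single-token trick this PR applies; this is only a sketch, and the name llama_grammar_token_is_valid and its placement are hypothetical, not an existing API:

// Hypothetical helper (not in llama.cpp): report whether `id` is acceptable
// under the grammar, rather than throwing like llama_grammar_accept_token.
static bool llama_grammar_token_is_valid(
        struct llama_context * ctx,
        struct llama_grammar * grammar,
        const llama_token      id,
        const float            logit) {
    llama_token_data       single_token_data       = { id, logit, 0.0f };
    llama_token_data_array single_token_data_array = { &single_token_data, 1, false };

    // llama_sample_grammar sets the logit of every grammar-rejected token to -INFINITY
    llama_sample_grammar(ctx, &single_token_data_array, grammar);

    return single_token_data_array.data[0].logit != -INFINITY;
}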


// If the token is not valid according to the grammar, perform resampling
if (!is_valid) {
LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());

// Restore logits from the copy
std::copy(original_logits.begin(), original_logits.end(), logits);

return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // resample with is_resampling = true, so the grammar is applied before sampling
}
}

return id;
}

llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
// Call the implementation function with is_resampling set to false by default
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
}

void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
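
As a usage note, here is a minimal sketch of how this entry point is typically driven from a generation loop; the loop scaffolding is omitted, ctx_cfg is assumed NULL (no classifier-free guidance), and the trailing arguments of llama_sampling_accept follow my reading of common/sampling.h from this period rather than this diff, which truncates that signature:

// Sample one token; after this PR, the grammar is applied to the full vocab
// only if the first sampled token violates it (the fast path stays cheap).
llama_token id = llama_sampling_sample(ctx_sampling, ctx_main, NULL, 0);

// Accept the token so repetition penalties and grammar state advance together
// (assumed signature: the last argument toggles grammar acceptance).
llama_sampling_accept(ctx_sampling, ctx_main, id, true);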