lookahead : add example for lookahead decoding #4207

Merged
merged 9 commits on Nov 26, 2023
lookahead : initial working implementation
ggerganov committed Nov 25, 2023
commit 61d039727a8460e369f41efb30a3bd9243555ff6
236 changes: 166 additions & 70 deletions examples/lookahead/lookahead.cpp
@@ -7,7 +7,11 @@
 #include <vector>
 
 struct seq_ngram {
-    bool active = false;
+    bool active = false;
+
+    llama_seq_id seq_id = -1;
+
+    std::vector<int> i_batch;
 
     std::vector<llama_token> tokens;
 };
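The two fields added to seq_ngram tie each verification candidate to the decode step: seq_id is the KV cache sequence the candidate is evaluated on, and i_batch records, for every token of the candidate, the index in the batch where that token's logits will land. A self-contained sketch of how the fields are filled for one candidate, mirroring the batch-construction code later in this diff (the llama typedefs are replaced with plain ints and the token id 42 is made up, so the snippet compiles on its own):

    #include <vector>

    // Self-contained sketch: how the new fields are filled for one candidate
    // (mirrors the batch-construction code later in this diff; the llama
    // typedefs are replaced with plain ints so the snippet compiles alone).
    using llama_seq_id = int;
    using llama_token  = int;

    struct seq_ngram {
        bool active = false;

        llama_seq_id seq_id = -1;        // KV cache sequence the candidate is decoded on

        std::vector<int> i_batch;        // batch index of each token's logits

        std::vector<llama_token> tokens; // candidate continuation, tokens[0] = current token
    };

    int main() {
        const int W = 10, N = 8;
        const int g = 0;                 // first verification candidate

        seq_ngram ng;
        ng.active = true;
        ng.tokens.resize(N);
        ng.i_batch.resize(N);
        ng.seq_id = W + 1 + g;           // verification sequences come after the W lookahead ones
        ng.i_batch[0] = 0;               // the current token sits at batch index 0
        ng.tokens [0] = 42;              // made-up current token id
        return 0;
    }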
@@ -34,9 +38,9 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    const int W = 5;  // lookahead window
-    const int N = 4;  // n-gram size
-    const int G = 5;  // max verification n-grams
+    const int W = 10; // lookahead window
+    const int N = 8;  // n-gram size
+    const int G = 10; // max verification n-grams
 
     const bool dump_kv_cache = params.dump_kv_cache;
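The constants move from toy values to more realistic ones: W is the width of the lookahead window, N the n-gram size (so there are N - 1 lookahead levels), and G the maximum number of verification n-grams checked per step. Following the batch construction that appears later in this diff, a single decode then carries at most (W + G)*(N - 1) = 140 tokens. A quick check of that arithmetic (editor's sketch, not part of the commit):

    #include <cstdio>

    // Worked size check for the lookahead batch (editor's sketch): the counts
    // follow the batch construction later in this diff.
    int main() {
        const int W = 10; // lookahead window
        const int N = 8;  // n-gram size
        const int G = 10; // max verification n-grams

        const int n_current = 1;           // the sampled token
        const int n_level0  = W - 1;       // rest of the first lookahead level
        const int n_levels  = W * (N - 2); // levels 1 .. N-2
        const int n_verify  = G * (N - 1); // up to G candidates, N-1 tokens each

        printf("max tokens per decode: %d\n", n_current + n_level0 + n_levels + n_verify); // 140
        printf("same as (W + G)*(N - 1): %d\n", (W + G)*(N - 1));                          // 140
        return 0;
    }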

@@ -89,7 +93,7 @@ int main(int argc, char ** argv) {
     llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
     llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
 
-    for (int s = 0; s < W + G + 1; ++s) {
+    for (int s = 1; s < W + G + 1; ++s) {
         llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
     }
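Starting the loop at s = 1 skips the pointless self-copy: the prompt is decoded into sequence 0 and then shared with the other W + G sequences (ids 1..W for the lookahead levels, W+1..W+G for the verification candidates). llama_kv_cache_seq_cp does not duplicate the data, it makes the same cells visible under another sequence id. A toy model of that bookkeeping, under assumed simplified semantics rather than the real implementation:

    #include <cstdio>
    #include <set>
    #include <vector>

    // Toy model of the KV cache bookkeeping (editor's sketch): each cached
    // cell remembers which sequence ids can see it, so a sequence copy only
    // adds an id to the cell instead of duplicating the data.
    struct kv_cell {
        int pos;
        std::set<int> seq_ids;
    };

    int main() {
        const int W = 10, G = 10;
        const int n_prompt = 4;

        std::vector<kv_cell> cache;
        for (int p = 0; p < n_prompt; ++p) {
            cache.push_back({p, {0}}); // prompt decoded into sequence 0
        }

        // mirrors: for (int s = 1; s < W + G + 1; ++s) llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
        for (int s = 1; s < W + G + 1; ++s) {
            for (auto & cell : cache) {
                cell.seq_ids.insert(s);
            }
        }

        printf("prompt cell 0 is visible to %zu sequences\n", cache[0].seq_ids.size()); // 21
        return 0;
    }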

@@ -114,15 +118,18 @@ int main(int argc, char ** argv) {
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
 
     // verification n-grams
-    std::vector<seq_ngram> ngrams(G);
+    std::vector<seq_ngram> ngrams_cur(G);
 
     // tokens for the past N - 1 Jacobi iterations
     std::vector<llama_token> tokens_j_prev(W);
     std::vector<std::vector<llama_token>> tokens_j(N - 1);
     for (int j = 0; j < N - 1; j++) {
         tokens_j[j].resize(W);
 
         for (int i = 0; i < W; i++) {
-            tokens_j[j][i] = all[1 + rand() % (all.size() - 1)];
+            // initialize randomly from the prompt tokens
+            //tokens_j[j][i] = all[1 + rand() % (all.size() - 1)];
+            tokens_j[j][i] = 100 + i;
         }
     }
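The lookahead levels are now seeded with fixed placeholder ids (100 + i) instead of random prompt tokens; the seeds only matter for the first iterations because every pass overwrites the levels with the model's own guesses. Both variants side by side, with a hypothetical three-token prompt:

    #include <cstdlib>
    #include <vector>

    // Sketch of the two seeding strategies (the commit switches random_init off);
    // the prompt tokens are hypothetical.
    int main() {
        const int W = 10, N = 8;
        const bool random_init = false;
        std::vector<int> all = {1, 15043, 3186}; // BOS + two prompt tokens

        std::vector<std::vector<int>> tokens_j(N - 1, std::vector<int>(W));
        for (int j = 0; j < N - 1; j++) {
            for (int i = 0; i < W; i++) {
                tokens_j[j][i] = random_init
                    ? all[1 + rand() % (all.size() - 1)] // old: random prompt token (skips BOS)
                    : 100 + i;                           // new: deterministic placeholder
            }
        }
        return 0;
    }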

@@ -168,113 +175,202 @@ int main(int argc, char ** argv) {
         {
             llama_batch_clear(batch);
 
             // current token - first token of the first level
             llama_batch_add(batch, id, n_past, seq_id_all, true);
 
+            // verification n-grams - queue this here for less KV cache fragmentation
+            {
+                const int g_cur = ngrams_observed.cnt[id];
+
+                ngrams_cur.resize(g_cur);
+                for (int g = 0; g < g_cur; g++) {
+                    ngrams_cur[g].active = true;
+                    ngrams_cur[g].tokens.resize(N);
+                    ngrams_cur[g].i_batch.resize(N);
+                    ngrams_cur[g].seq_id = W + 1 + g;
+                    ngrams_cur[g].i_batch[0] = 0;
+                    ngrams_cur[g].tokens [0] = id;
+                }
+
+                for (int j = 0; j < N - 1; j++) {
+                    for (int g = 0; g < g_cur; g++) {
+                        const int idx = id*(N - 1)*G + g*(N - 1);
+
+                        const llama_token t = ngrams_observed.tokens[idx + j];
+
+                        ngrams_cur[g].tokens [j + 1] = t;
+                        ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;
+
+                        llama_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
+                    }
+                }
+            }
+
             // fill the remaining W - 1 tokens for the first level
             for (int i = 1; i < W; i++) {
                 llama_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
             }
 
             // fill the rest of the levels
             for (int j = 1; j < N - 1; j++) {
                 for (int i = 0; i < W; i++) {
                     llama_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
                 }
             }
         }
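Everything is evaluated in one llama_decode call: batch index 0 holds the freshly sampled token (visible to all sequences), the next ngrams_cur.size()*(N - 1) entries hold the verification candidates added column by column (which is exactly what i_batch records), then W - 1 tokens complete the first lookahead level and the deeper levels follow. This layout also explains the sampling index ngrams_cur.size()*(N-1) + W*(N - 2) + i used further down: it is the batch position of column i of the last level. A sketch that replays the construction order with an assumed g_cur = 3 and prints where each group lands (editor's sketch, not part of the commit):

    #include <cstdio>

    // Editor's sketch: replay where each group of tokens lands in the batch,
    // matching the construction order above (current token, verification
    // candidates column by column, first lookahead level, deeper levels).
    int main() {
        const int W = 10, N = 8;
        const int g_cur = 3; // assumed: 3 candidates start with the current token

        int n = 0;
        printf("batch[%3d]        : current token (all sequences, logits on)\n", n++);

        for (int j = 0; j < N - 1; j++) {
            for (int g = 0; g < g_cur; g++) {
                printf("batch[%3d]        : candidate %d, token %d (logits on)\n", n++, g, j + 1);
            }
        }

        printf("batch[%3d .. %3d] : first lookahead level (no logits)\n", n, n + W - 2);
        n += W - 1;

        for (int j = 1; j < N - 1; j++) {
            printf("batch[%3d .. %3d] : lookahead level %d (logits %s)\n",
                   n, n + W - 1, j, j == N - 2 ? "on" : "off");
            n += W;
        }

        // sampling index of column i in the last level, as used above:
        // ngrams_cur.size()*(N-1) + W*(N - 2) + i
        printf("last level starts at %d = g_cur*(N-1) + W*(N-2)\n", g_cur*(N - 1) + W*(N - 2));
        printf("total tokens: %d\n", n);
        return 0;
    }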

-        // TODO: add verification n-grams
+        if (llama_decode(ctx, batch) != 0) {
+            fprintf(stderr, "\n\n%s: error: llama_decode failed - increase KV cache size\n", __func__);
+            return 1;
+        }
 
-        llama_decode(ctx, batch);
+        int seq_id_best = 0;
 
-        id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
+        for (int v = 0; v < N; ++v) {
+            int i_batch = 0;
 
-        llama_sampling_accept(ctx_sampling, ctx, id, true);
+            if (v > 0) {
+                for (int g = 0; g < (int) ngrams_cur.size(); g++) {
+                    if (ngrams_cur[g].active) {
+                        i_batch = ngrams_cur[g].i_batch[v];
+                        seq_id_best = ngrams_cur[g].seq_id;
+                        break;
+                    }
+                }
 
-        {
-            const std::string token_str = llama_token_to_piece(ctx, id);
+                // no more matches
+                if (i_batch == 0) {
+                    break;
+                }
+            }

-            printf("%s", token_str.c_str());
-            fflush(stdout);
+            id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
 
-            if (id == llama_token_eos(model)) {
-                has_eos = true;
-            }
-        }
+            llama_sampling_accept(ctx_sampling, ctx, id, true);
 
-        ++n_predict;
-        ++n_past;
+            {
+                const std::string token_str = llama_token_to_piece(ctx, id);
 
-        if (n_predict > params.n_predict || has_eos) {
-            break;
-        }
+                if (v == 0) {
+                    printf("%s", token_str.c_str());
+                } else {
+                    // print light cyan
+                    printf("\033[0;96m%s\033[0m", token_str.c_str());
+                }
+                fflush(stdout);
 
+                if (id == llama_token_eos(model)) {
+                    has_eos = true;
+                }

-        // print known n-grams starting with token id
-        if (1) {
-            if (ngrams_observed.cnt[id] > 0) {
-                printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], llama_token_to_piece(ctx, id).c_str());
+                all.push_back(id);
+            }
 
-            for (int i = 0; i < ngrams_observed.cnt[id]; i++) {
-                printf("   - ngram %2d: ", i);
+            ++n_predict;
+            ++n_past;
 
-                const int idx = id*(N - 1)*G + i*(N - 1);
+            if (n_predict > params.n_predict || has_eos) {
+                break;
+            }
 
-                for (int j = 0; j < N - 1; j++) {
-                    const std::string token_str = llama_token_to_piece(ctx, ngrams_observed.tokens[idx + j]);
+            // verify across active n-grams
+            for (int g = 0; g < (int) ngrams_cur.size(); g++) {
+                if (ngrams_cur[g].active) {
+                    if (v == N - 1) {
+                        ngrams_cur[g].active = false;
+                    } else {
+                        if (id != ngrams_cur[g].tokens[v + 1]) {
+                            ngrams_cur[g].active = false;
+                        } else {
+                        }
+                    }
+                }
+            }
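A candidate stays active in round v only if the freshly sampled token equals the candidate's token v + 1; otherwise it is switched off, and after the last round everything is retired. Equivalently, the number of tokens accepted from one candidate is the length of the matching prefix between the candidate and the tokens the model actually samples along its sequence. A hypothetical helper (not in the commit) stating the same rule non-incrementally:

    #include <vector>

    // Hypothetical helper (not in the commit): how many tokens a single
    // candidate n-gram yields, given the tokens the model actually samples
    // along that candidate's sequence. The verification loop above applies
    // the same rule incrementally: candidate g is deactivated as soon as
    // id != ngrams_cur[g].tokens[v + 1].
    static int n_accepted(const std::vector<int> & ngram,    // ngram[0] = current token
                          const std::vector<int> & sampled) { // sampled[v] = token from round v
        int n = 0;
        for (int v = 0; v + 1 < (int) ngram.size() && v < (int) sampled.size(); ++v) {
            if (sampled[v] != ngram[v + 1]) {
                break;
            }
            ++n;
        }
        return n;
    }

    int main() {
        // candidate continuation: 7 8 9 after current token 3; model samples 7 8 5
        return n_accepted({3, 7, 8, 9}, {7, 8, 5}) == 2 ? 0 : 1;
    }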

-                printf("%s", token_str.c_str());
+            // print known n-grams starting with token id
+            if (0) {
+                if (ngrams_observed.cnt[id] > 0) {
+                    printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], llama_token_to_piece(ctx, id).c_str());
+                }
 
-                printf("\n");
-            }
-        }
+                for (int i = 0; i < ngrams_observed.cnt[id]; i++) {
+                    printf("   - ngram %2d: ", i);
 
-        // update Jacobi tokens (or whatever these are called)
-        {
-            for (int i = 0; i < W; i++) {
-                tokens_j_prev[i] = tokens_j[0][i];
-            }
+                    const int idx = id*(N - 1)*G + i*(N - 1);
 
+                    for (int j = 0; j < N - 1; j++) {
+                        const std::string token_str = llama_token_to_piece(ctx, ngrams_observed.tokens[idx + j]);
 
-            for (int j = 0; j < N - 2; j++) {
-                tokens_j[j] = tokens_j[j + 1];
+                        printf("%s", token_str.c_str());
+                    }
 
+                    printf("\n");
+                }
+            }
 
-            for (int i = 0; i < W; i++) {
-                tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, W*(N - 2) + i);
+            // update Jacobi tokens (or whatever these are called)
+            {
+                for (int i = 0; i < W; i++) {
+                    tokens_j_prev[i] = tokens_j[0][i];
+                }
 
+                for (int j = 0; j < N - 2; j++) {
+                    tokens_j[j] = tokens_j[j + 1];
+                }
 
+                if (v == 0) {
+                    // sample from the last level
+                    for (int i = 0; i < W; i++) {
+                        tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
+                    }
+                } else {
+                    for (int i = 0; i < W; i++) {
+                        // random init
+                        //tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)];
+                        tokens_j[N - 2][i] = tokens_j[0][i];
+                    }
+                }
+            }
-        }

-        // update observed ngrams
-        {
-            // the first token of the n-gram is determined by the index in the container so it is not stored
-            std::vector<llama_token> ngram(N - 1);
+            // update observed ngrams
+            {
+                // the first token of the n-gram is determined by the index in the container so it is not stored
+                std::vector<llama_token> ngram(N - 1);
 
-            // n-gram generation
-            for (int f = 0; f < W; ++f) {
-                for (int j = 0; j < N - 1; ++j) {
-                    ngram[j] = tokens_j[j][f];
-                };
+                // n-gram generation
+                // ref: https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1826198518
+                for (int f = 0; f < W; ++f) {
+                    for (int j = 0; j < N - 1; ++j) {
+                        ngram[j] = tokens_j[j][f];
+                    };
 
-                const int ft = tokens_j_prev[f]; // first token of the n-gram
-                const int head = ngrams_observed.head[ft];
-                const int idx = ft*(N - 1)*G + head*(N - 1);
+                    const int ft = tokens_j_prev[f]; // first token of the n-gram
+                    const int head = ngrams_observed.head[ft];
+                    const int idx  = ft*(N - 1)*G + head*(N - 1);
 
-                for (int i = 0; i < N - 1; i++) {
-                    ngrams_observed.tokens[idx + i] = ngram[i];
-                }
+                    for (int i = 0; i < N - 1; i++) {
+                        ngrams_observed.tokens[idx + i] = ngram[i];
+                    }
 
-                ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1);
-                ngrams_observed.head[ft] = (head + 1) % G;
+                    ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1);
+                    ngrams_observed.head[ft] = (head + 1) % G;
 
-                ngrams_observed.n_total++;
+                    ngrams_observed.n_total++;
+                }
+            }
+        }
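The pool of observed n-grams is a single flat array indexed by the first token: each first token ft owns G slots of N - 1 tokens each, written ring-buffer style, with head[ft] advancing modulo G while cnt[ft] saturates at G, so the oldest entries are eventually overwritten. A runnable sketch of the same indexing, with an assumed vocabulary size:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Editor's sketch of the flat n-gram pool used above: for each possible
    // first token ft there are G slots, each holding the N-1 follow-up tokens.
    // head[ft] is a ring-buffer write position, cnt[ft] saturates at G.
    int main() {
        const int N = 8, G = 10;
        const int n_vocab = 32000; // assumed vocabulary size

        std::vector<int> tokens(n_vocab * (N - 1) * G);
        std::vector<int> cnt (n_vocab, 0);
        std::vector<int> head(n_vocab, 0);

        auto insert = [&](int ft, const std::vector<int> & ngram) {
            const int idx = ft*(N - 1)*G + head[ft]*(N - 1);
            for (int i = 0; i < N - 1; i++) {
                tokens[idx + i] = ngram[i];
            }
            cnt [ft] = std::min(G, cnt[ft] + 1);
            head[ft] = (head[ft] + 1) % G;
        };

        insert(42, {1, 2, 3, 4, 5, 6, 7});
        printf("n-grams stored for token 42: %d\n", cnt[42]); // 1
        return 0;
    }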

-        // verification
-        // TODO
-        {
-        }
 
         llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
 
+        if (seq_id_best != 0) {
+            llama_kv_cache_seq_keep(ctx, seq_id_best);
+            llama_kv_cache_seq_cp  (ctx, seq_id_best, 0, -1, -1);
+            llama_kv_cache_seq_rm  (ctx, seq_id_best, -1, -1);
 
+            for (int s = 1; s < W + G + 1; ++s) {
+                llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
+            }
+        }
     }
 
     auto t_dec_end = ggml_time_us();
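The per-iteration cleanup first drops every speculative cache entry past n_past and, when a verification sequence contributed accepted tokens (seq_id_best != 0), promotes that sequence to sequence 0 before re-sharing it, since the accepted positions only exist in that sequence's cells. A toy continuation of the earlier cache sketch, under the same assumed simplified semantics of the llama_kv_cache_seq_* calls:

    #include <cstdio>
    #include <set>
    #include <vector>

    // Toy continuation of the earlier cache sketch: promote the accepted
    // verification sequence to sequence 0 and re-share it, mirroring the
    // keep/cp/rm calls above (assumed simplified semantics).
    struct kv_cell {
        int pos;
        std::set<int> seq_ids;
    };

    int main() {
        const int W = 10, G = 10;
        const int seq_id_best = 13; // assumed: the candidate on sequence 13 was accepted

        // position 5 was an accepted token: it exists only in sequence 13's cells
        std::vector<kv_cell> cache = { {4, {0, seq_id_best}}, {5, {seq_id_best}} };

        // llama_kv_cache_seq_keep: drop every cell the best sequence cannot see
        std::vector<kv_cell> kept;
        for (auto & c : cache) {
            if (c.seq_ids.count(seq_id_best)) {
                kept.push_back(c);
            }
        }
        cache = kept;

        // llama_kv_cache_seq_cp(best -> 0), then llama_kv_cache_seq_rm(best)
        for (auto & c : cache) {
            c.seq_ids.insert(0);
            c.seq_ids.erase(seq_id_best);
        }

        // re-share sequence 0 with all lookahead/verification sequences
        for (int s = 1; s < W + G + 1; ++s) {
            for (auto & c : cache) {
                c.seq_ids.insert(s);
            }
        }

        printf("cells kept: %zu, position 5 now on sequence 0: %d\n",
               cache.size(), (int) cache[1].seq_ids.count(0)); // 2, 1
        return 0;
    }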