Prompt lookup decoding #4484

Merged 10 commits on Dec 22, 2023

1 change: 1 addition & 0 deletions .gitignore
@@ -48,6 +48,7 @@ models-mnt
/llama-bench
/llava-cli
/lookahead
/lookup
/main
/metal
/perplexity
5 changes: 4 additions & 1 deletion Makefile
@@ -2,7 +2,7 @@
BUILD_TARGETS = \
main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead tests/test-c.o
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup tests/test-c.o

# Binaries only useful for tests
TEST_TARGETS = \
@@ -639,6 +639,9 @@ parallel: examples/parallel/parallel.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
lookahead: examples/lookahead/lookahead.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

ifdef LLAMA_METAL
metal: examples/metal/metal.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
3 changes: 2 additions & 1 deletion common/common.h
@@ -51,7 +51,7 @@ struct gpt_params {
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_draft = 16; // number of tokens to draft during speculative decoding
int32_t n_draft = 8; // number of tokens to draft during speculative decoding
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_sequences = 1; // number of sequences to decode
@@ -240,3 +240,4 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);

// Dump the KV cache view showing individual sequences in each cell (long output).
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);

1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -33,6 +33,7 @@ else()
add_subdirectory(simple)
add_subdirectory(speculative)
add_subdirectory(lookahead)
add_subdirectory(lookup)
add_subdirectory(train-text-from-scratch)
if (LLAMA_METAL)
add_subdirectory(metal)
5 changes: 5 additions & 0 deletions examples/lookup/CMakeLists.txt
@@ -0,0 +1,5 @@
set(TARGET lookup)
add_executable(${TARGET} lookup.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
13 changes: 13 additions & 0 deletions examples/lookup/README.md
@@ -0,0 +1,13 @@
# llama.cpp/examples/lookup

Demonstration of Prompt Lookup Decoding

https://github.com/apoorvumang/prompt-lookup-decoding

The key parameters for lookup decoding are `ngram_min`, `ngram_max` and `n_draft`. The first two determine the minimum and maximum size of the n-grams to search for in the prompt; the latter specifies how many subsequent tokens to draft when a match is found (see the example invocation below).

More info:

https://github.com/ggerganov/llama.cpp/pull/4484
https://github.com/ggerganov/llama.cpp/issues/4226
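
As a rough illustration, once built the example can be run like the other speculative examples. This is a sketch rather than part of the PR: the model path is a placeholder, `--draft` is the common flag behind `n_draft`, and `ngram_min`/`ngram_max` are compile-time constants in `lookup.cpp`.

```sh
# placeholder model path; prompt lookup pays off most on prompts with repetition
./lookup -m models/7B/ggml-model.gguf --draft 8 -p "..."
```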

230 changes: 230 additions & 0 deletions examples/lookup/lookup.cpp
@@ -0,0 +1,230 @@
#include "common.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <string>
#include <vector>

int main(int argc, char ** argv){
gpt_params params;

if (!gpt_params_parse(argc, argv, params)) {
return 1;
}

// max/min n-gram sizes to search for in the prompt
const int ngram_max = 4;
const int ngram_min = 1;

// length of the candidate / draft sequence, if a match is found
const int n_draft = params.n_draft;

const bool dump_kv_cache = params.dump_kv_cache;

#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("lookup", "log"));
LOG_TEE("Log start\n");
log_dump_cmdline(argc, argv);
#endif // LOG_DISABLE_LOGS

// init llama.cpp
llama_backend_init(params.numa);

llama_model * model = NULL;
llama_context * ctx = NULL;

// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);

// tokenize the prompt
const bool add_bos = llama_should_add_bos_token(model);
LOG("add_bos tgt: %d\n", add_bos);

std::vector<llama_token> inp;
inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);

const int max_context_size = llama_n_ctx(ctx);
const int max_tokens_list_size = max_context_size - 4;

if ((int) inp.size() > max_tokens_list_size) {
fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
return 1;
}

fprintf(stderr, "\n\n");

for (auto id : inp) {
fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
}

fflush(stderr);

const int n_input = inp.size();

const auto t_enc_start = ggml_time_us();

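// process the prompt in two batches: everything but the last token first, then
// the last token on its own so that its logits are available for the first sample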
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));

const auto t_enc_end = ggml_time_us();

int n_predict = 0;
int n_drafted = 0;
int n_accept = 0;

int n_past = inp.size();

bool has_eos = false;

struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);

std::vector<llama_token> draft;

llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);

// debug
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);

const auto t_dec_start = ggml_time_us();

while (true) {
// debug
if (dump_kv_cache) {
llama_kv_cache_view_update(ctx, &kvc_view);
dump_kv_cache_view_seqs(kvc_view, 40);
}

// print current draft sequence
LOG("drafted %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str());

int i_dft = 0;
while (true) {
// sample from the target model
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);

llama_sampling_accept(ctx_sampling, ctx, id, true);

const std::string token_str = llama_token_to_piece(ctx, id);

if (!params.use_color) {
printf("%s", token_str.c_str());
}

if (id == llama_token_eos(model)) {
has_eos = true;
}

++n_predict;

// check if the target token matches the draft
if (i_dft < (int) draft.size() && id == draft[i_dft]) {
LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str());
++n_accept;
++n_past;
++i_dft;
inp.push_back(id);

if (params.use_color) {
// color accepted draft token
printf("\033[34m%s\033[0m", token_str.c_str());
fflush(stdout);
}
continue;
}

if (params.use_color) {
printf("%s", token_str.c_str());
}
fflush(stdout);


LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());

draft.clear();
draft.push_back(id);
inp.push_back(id);
break;
}

if ((params.n_predict > 0 && n_predict > params.n_predict) || has_eos) {
break;
}

// KV cache management
// clean the cache of draft tokens that weren't accepted
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);

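// the target batch always starts with the last sampled token; any draft
// candidates found by prompt lookup are appended after it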
llama_batch_clear(batch_tgt);
llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);

// generate up to n_draft candidate tokens through prompt lookup
auto prompt_lookup = [&]() -> void {
int inp_size = inp.size();
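// search for the longest suffix n-gram of inp that also occurs earlier in inp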
for (int ngram_size = ngram_max; ngram_size > ngram_min; --ngram_size) {
const llama_token * ngram = &inp[inp_size - ngram_size];

for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {
bool match = true;
for (int j = 0; j < ngram_size; ++j) {
if (inp[i + j] != ngram[j]) {
match = false;
break;
}
}

if (match) {
const int startIdx = i + ngram_size;
const int endIdx = startIdx + n_draft;
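// only draft if a full window of n_draft tokens follows the match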
if (endIdx < inp_size) {
for (int j = startIdx; j < endIdx; ++j) {
LOG(" - draft candidate %d: %d\n", j, inp[j]);
draft.push_back(inp[j]);
llama_batch_add(batch_tgt, inp[j], n_past + (j - startIdx) + 1, { 0 }, true);
++n_drafted;
}
return;
}
}
}
}
return;
};

prompt_lookup();

llama_decode(ctx, batch_tgt);
++n_past;

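// draft[0] was the sampled token that has just been decoded; drop it so that
// draft only holds the lookup candidates, aligned with i_dft on the next pass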
draft.erase(draft.begin());
}

auto t_dec_end = ggml_time_us();

LOG_TEE("\n\n");

LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));

LOG_TEE("\n");
LOG_TEE("n_draft = %d\n", n_draft);
LOG_TEE("n_predict = %d\n", n_predict);
LOG_TEE("n_drafted = %d\n", n_drafted);
LOG_TEE("n_accept = %d\n", n_accept);
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);

LOG_TEE("\ntarget:\n");
llama_print_timings(ctx);

llama_sampling_free(ctx_sampling);
llama_batch_free(batch_tgt);

llama_free(ctx);
llama_free_model(model);

llama_backend_free();

fprintf(stderr, "\n\n");

return 0;
}
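
For readers who want to study the core trick in isolation, the following is a minimal, self-contained sketch of the same n-gram matching over plain `int` tokens, with the llama.cpp types stripped out. It is illustrative and not part of the PR: `prompt_lookup_draft` is a hypothetical name, the token values are invented, and the loop bounds mirror the original (including the `ngram_size > ngram_min` condition, under which n-grams of exactly `ngram_min` tokens are never searched).

#include <cstdio>
#include <vector>

// Given the tokens seen so far, find an earlier occurrence of the longest
// suffix n-gram and return the tokens that followed it as draft candidates.
static std::vector<int> prompt_lookup_draft(const std::vector<int> & inp,
                                            int ngram_min, int ngram_max, int n_draft) {
    const int inp_size = (int) inp.size();

    for (int ngram_size = ngram_max; ngram_size > ngram_min; --ngram_size) {
        // the suffix n-gram we are trying to find earlier in inp
        const int * ngram = &inp[inp_size - ngram_size];

        // stop early enough that a candidate window cannot overlap the suffix itself
        for (int i = 0; i <= inp_size - 2*ngram_size; ++i) {
            bool match = true;
            for (int j = 0; j < ngram_size; ++j) {
                if (inp[i + j] != ngram[j]) { match = false; break; }
            }

            if (match) {
                const int start_idx = i + ngram_size;
                // only draft if a full window of n_draft tokens follows the match
                if (start_idx + n_draft < inp_size) {
                    return std::vector<int>(inp.begin() + start_idx,
                                            inp.begin() + start_idx + n_draft);
                }
            }
        }
    }

    return {}; // no usable match found
}

int main() {
    // token ids standing in for "A B C D A E F A B C": the suffix n-gram
    // {1, 2, 3} also occurs at position 0, so the tokens after it are drafted
    const std::vector<int> inp = { 1, 2, 3, 4, 1, 5, 6, 1, 2, 3 };

    for (const int t : prompt_lookup_draft(inp, /*ngram_min=*/1, /*ngram_max=*/3, /*n_draft=*/2)) {
        printf("%d ", t); // prints: 4 1
    }
    printf("\n");

    return 0;
}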