Commit
async speculation
okuvshynov committed Apr 23, 2024
1 parent 4e96a81 commit c8d446d
Showing 5 changed files with 353 additions and 3 deletions.
4 changes: 4 additions & 0 deletions Makefile
@@ -756,6 +756,10 @@ simple: examples/simple/simple.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

async_spec: examples/async_spec/async_spec.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

tokenize: examples/tokenize/tokenize.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
8 changes: 8 additions & 0 deletions examples/async_spec/CMakeLists.txt
@@ -0,0 +1,8 @@
set(TARGET async_spec)
add_executable(${TARGET} async_spec.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
endif()
338 changes: 338 additions & 0 deletions examples/async_spec/async_spec.cpp
@@ -0,0 +1,338 @@
#include "common.h"
#include "llama.h"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

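// Asynchronous speculative decoding with two threads sharing one token
// sequence:
// - draft_loop (draft model, CPU) keeps extending the shared speculation
//   one greedy token at a time.
// - main_loop (main model, GPU) verifies the speculated tokens in batches:
//   it accepts every token the two models agree on plus one token of its
//   own, then reconciles the result with the shared speculation.
// Both loops sample greedily, so the accepted output matches what the main
// model would produce on its own.
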
// debug visualization utils for tokens which were:
// - matched:     generated by both the draft and the main model
// - rejected:    produced by the draft (speculation) model but not confirmed by the main model
// - not matched: produced by the main model (and accepted) but never proposed by the draft model
static void dbg_color(const std::string& s, const std::string& fg, size_t bg_index) {
static const std::vector<std::string> kBackgrounds = {"\033[40m", "\033[48m"};
static const std::string kReset = "\033[0m";
std::cout
<< kBackgrounds[bg_index % kBackgrounds.size()]
<< fg << s << kReset << std::flush;
}

static void dbg_accepted(const std::string& accepted, size_t bg_index) {
static const std::string kGreen = "\033[32m";
dbg_color(accepted, kGreen, bg_index);
}

// something the main model generated and accepted, but which was not proposed by the draft model
static void dbg_not_matched(const std::string& accepted, size_t bg_index) {
dbg_color(accepted, "", bg_index);
}

static void dbg_rejected(const std::string& rejected, size_t bg_index) {
static const std::string kRed = "\033[31m";
dbg_color(rejected, kRed, bg_index);
}

// state shared between the main (verification) loop and the speculation (draft) loop
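// - speculation: the prompt, the tokens agreed on so far, and the draft's
//                current continuation of them
// - mtx:         guards both fields; each loop holds it only briefly
// - done:        set by the main loop to tell the draft loop to stop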
struct linear_speculative_context {
std::vector<llama_token> speculation;
std::mutex mtx;
bool done;
};

// greedy (argmax) sampling for each set of logits in [from_idx, to_idx)
static std::vector<llama_token> greedy_tokens(llama_model* model, llama_context* ctx, int from_idx, int to_idx) {
auto n_vocab = llama_n_vocab(model);
std::vector<llama_token_data> candidates;
candidates.resize(n_vocab);
std::vector<llama_token> res;

for (int idx = from_idx; idx < to_idx; idx++) {
auto * logits = llama_get_logits_ith(ctx, idx);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates[token_id] = llama_token_data{ token_id, logits[token_id], 0.0f };
}

llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

// sample the most likely token
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
res.push_back(new_token_id);
}
return res;
}

static int main_loop(
llama_model *model,
linear_speculative_context *spec_ctx,
llama_context *ctx,
std::vector<llama_token> tokens_list /* copy here */) {
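// n_len caps the total sequence length (prompt + generated tokens)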
const int n_len = 1024;

llama_batch batch = llama_batch_init(1024, 0, 1);

// evaluate the initial prompt
for (size_t i = 0; i < tokens_list.size(); i++) {
llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
}

// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;

if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}

// number of tokens accepted so far (initially the prompt length)
int n_cur = batch.n_tokens;

std::vector<llama_token> input_seq, next_tokens;
input_seq.push_back(tokens_list.back());

int logits_from = n_cur - 1;
int logits_to = n_cur;
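// bg_index alternates the debug background color once per verification round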
size_t bg_index = 0;

while (n_cur <= n_len) {
bg_index++;
next_tokens = greedy_tokens(model, ctx, logits_from, logits_to);
if (next_tokens.size() != input_seq.size()) {
fprintf(stderr, "invalid next tokens\n");
return 1;
}

// position in the sequence where next_tokens starts
int next_tokens_pos = n_cur;
// we always accept at least one new token
n_cur += 1;
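// Compare the main model's greedy picks against the speculated input:
// keep accepting while they agree; on the first disagreement, truncate
// next_tokens so it ends with the main model's correction.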
for (size_t i = 0; i + 1 < input_seq.size(); i++) {
if (next_tokens[i] == input_seq[i + 1]) {
n_cur += 1;
} else {
// reject. next_tokens[i] is the last 'correct' one.
next_tokens.erase(next_tokens.begin() + i + 1, next_tokens.end());
break;
}
}
// drop KV cache entries from the last accepted token onward: that token is
// re-decoded as the first element of the next batch, and everything after it
// was speculative
llama_kv_cache_seq_rm(ctx, 0, n_cur - 1, -1);

bool done = false;
for (llama_token new_token_id: next_tokens) {
if (new_token_id == llama_token_eos(model)) {
done = true;
}
}
if (n_cur >= n_len || done) {
break;
}

// CRITICAL SECTION -- reconcile main and speculative
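// Under the lock, compare the freshly accepted tokens with the shared
// speculation: tokens the draft already guessed stay in place; on the
// first mismatch the stale tail of the speculation is discarded and
// replaced by the main model's output.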
{
std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
auto& spec = spec_ctx->speculation;
size_t n_match = 0;
for (size_t i = 0; i < next_tokens.size() && i + next_tokens_pos < spec.size(); i++) {
if (next_tokens[i] == spec[i + next_tokens_pos]) {
n_match++;
} else {
break;
}
}

std::string accepted = "";
// Debug-print the accepted / rejected / not-matched pieces.
// Building the strings this way is slow, but fine for short sequences.
for (size_t i = next_tokens_pos; i < next_tokens_pos + n_match; i++) {
accepted += llama_token_to_piece(ctx, spec[i]);
}
dbg_accepted(accepted, bg_index);
if (n_match != next_tokens.size()) {
std::string rejected = "";
for (size_t i = next_tokens_pos + n_match; i < spec.size(); i++) {
rejected += llama_token_to_piece(ctx, spec[i]);
}
dbg_rejected(rejected, bg_index);
// need to modify speculation
spec.erase(spec.begin() + next_tokens_pos, spec.end());
for (const auto tok: next_tokens) {
spec.push_back(tok);
}
std::string not_matched = "";
for (size_t i = n_match; i < next_tokens.size(); i++) {
not_matched += llama_token_to_piece(ctx, next_tokens[i]);
}
dbg_not_matched(not_matched, bg_index);
}

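// The next batch for the main model starts at the last accepted token
// (its logits drive the next comparison) and covers everything the
// draft has speculated beyond it.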
input_seq.assign(spec.begin() + n_cur - 1, spec.end());
}

llama_batch_clear(batch);
for (size_t i = 0; i < input_seq.size(); i++) {
llama_batch_add(batch, input_seq[i], n_cur - 1 + i, { 0 }, true);
}
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
logits_from = 0;
logits_to = input_seq.size();
}

for (size_t i = 0; i < next_tokens.size(); i++) {
dbg_not_matched(llama_token_to_piece(ctx, next_tokens[i]), bg_index);
}
std::cout << std::endl << std::endl;
{
std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
spec_ctx->done = true;
}

llama_batch_free(batch);
return 0;
}

static int draft_loop(
llama_model *model,
linear_speculative_context *spec_ctx,
llama_context *ctx,
std::vector<llama_token> tokens_list /* copy here */) {

llama_batch batch = llama_batch_init(512, 0, 1);

// evaluate the initial prompt
for (size_t i = 0; i < tokens_list.size(); i++) {
llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
}

// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;

if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}

int logit_idx = batch.n_tokens - 1;
std::vector<llama_token> local_spec = tokens_list;
size_t match_len;

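// Keep extending a local copy of the speculation one greedy token at a
// time, synchronizing it with the shared copy after every step.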
while (true) {
auto next_tokens = greedy_tokens(model, ctx, logit_idx, logit_idx + 1);
if (next_tokens.size() != 1) {
fprintf(stderr, "invalid next tokens\n");
return 1;
}

local_spec.push_back(next_tokens[0]);

{
std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
if (spec_ctx->done) {
break;
}
auto& spec = spec_ctx->speculation;
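// If the local and shared speculation agree on their common prefix,
// publish the local copy; otherwise the main model has corrected the
// sequence, so adopt the shared copy and roll the draft KV cache back
// to the first differing position.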
bool match = true;
match_len = local_spec.size() - 1;
for (size_t i = 0; i < std::min(spec.size(), local_spec.size()); i++) {
if (spec[i] != local_spec[i]) {
match = false;
match_len = i;
llama_kv_cache_seq_rm(ctx, 0, i, -1);
break;
}
}
if (match) {
spec = local_spec;
} else {
local_spec = spec;
}
}

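// Feed the draft model only the suffix it has not evaluated yet (from the
// first position whose KV entry was discarded or never computed).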
llama_batch_clear(batch);
// TODO theoretically this can be empty?
for (size_t i = match_len; i < local_spec.size(); i++) {
llama_batch_add(batch, local_spec[i], i, { 0 }, true);
}

logit_idx = batch.n_tokens - 1;

// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
}

llama_batch_free(batch);
return 0;
}

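// Usage (based on the argument handling below):
//   ./async_spec [main_model.gguf] [draft_model.gguf] [prompt]
// The main model is fully offloaded to the GPU, the draft model stays on
// the CPU, and each runs in its own thread.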
int main(int argc, char ** argv) {
gpt_params params;

llama_backend_init();
llama_numa_init(params.numa);

// init context params
llama_context_params ctx_params = llama_context_default_params();
ctx_params.seed = 1234;
ctx_params.n_ctx = 2048;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

// Init main model and context; offload all layers to the GPU (n_gpu_layers = 99)
if (argc >= 2) {
params.model = argv[1];
}
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = 99;
llama_model *main_model = llama_load_model_from_file(params.model.c_str(), model_params);
llama_context *main_ctx = llama_new_context_with_model(main_model, ctx_params);

// Init draft model and context; keep it on the CPU (n_gpu_layers = 0)
if (argc >= 3) {
params.model = argv[2];
}
model_params.n_gpu_layers = 0;
llama_model *draft_model = llama_load_model_from_file(params.model.c_str(), model_params);
llama_context *draft_ctx = llama_new_context_with_model(draft_model, ctx_params);

// Print & tokenize the prompt. Both models must use the same tokenizer so
// that prompt and speculated tokens are directly comparable.
if (argc >= 4) {
params.prompt = argv[3];
}
if (params.prompt.empty()) {
params.prompt = "What's the difference between instruction cache and data cache?";
}
dbg_not_matched(params.prompt, 0);
std::vector<llama_token> tokens_list = llama_tokenize(main_ctx, params.prompt, true);

// Init shared speculative context
linear_speculative_context spec_ctx;
spec_ctx.speculation = tokens_list;
spec_ctx.done = false;

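// Run the draft loop (CPU) and the main loop (GPU) concurrently; they
// communicate only through spec_ctx.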
const auto t_main_start = ggml_time_us();
std::thread t_cpu(draft_loop, draft_model, &spec_ctx, draft_ctx, tokens_list);
std::thread t_gpu(main_loop, main_model, &spec_ctx, main_ctx, tokens_list);
t_gpu.join();
t_cpu.join();
const auto t_main_end = ggml_time_us();

printf("Total time: %.3lf\n", (t_main_end - t_main_start) / 1000000.0);

llama_free_model(main_model);
llama_free(main_ctx);
llama_free_model(draft_model);
llama_free(draft_ctx);
llama_backend_free();

return 0;
}
1 change: 1 addition & 0 deletions examples/async_spec/in.txt
@@ -0,0 +1 @@
<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful, respectful and honest assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nPlease give a detailed description of concurrency and parallelism in Python. Provide some examples.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n
5 changes: 2 additions & 3 deletions examples/simple/simple.cpp
@@ -27,7 +27,7 @@ int main(int argc, char ** argv) {
}

// total length of the sequence including the prompt
- const int n_len = 32;
+ const int n_len = 1024;

// init LLM

@@ -103,6 +103,7 @@ int main(int argc, char ** argv) {
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;

+ const auto t_main_start = ggml_time_us();
if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
@@ -113,8 +114,6 @@
int n_cur = batch.n_tokens;
int n_decode = 0;

- const auto t_main_start = ggml_time_us();
-
while (n_cur <= n_len) {
// sample the next token
{
