Commit c8d446d · 1 parent 4e96a81 · Showing 5 changed files with 353 additions and 3 deletions.
@@ -0,0 +1,8 @@
set(TARGET async_spec)
add_executable(${TARGET} async_spec.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
    add_dependencies(${TARGET} BUILD_INFO)
endif()
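
The commit touches 5 files but only 3 appear in this view; presumably one of the remaining hunks registers the new example directory with the top-level examples build. A minimal sketch of what that registration would look like (an assumption — the actual hunk is not shown here):

# hypothetical addition to the examples CMakeLists.txt (not part of this view)
add_subdirectory(async_spec)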
@@ -0,0 +1,338 @@
#include "common.h" | ||
#include "llama.h" | ||
|
||
#include <cmath> | ||
#include <cstdio> | ||
#include <mutex> | ||
#include <string> | ||
#include <thread> | ||
#include <vector> | ||
|
||
// debug visualization utils for tokens which were: | ||
// - matched - generated by both models | ||
// - rejected - produced by speculation model and did not match | ||
// - no_matched - produced by main model (and accepted) but not matched with speculation model | ||
// generated by both draft and main models | ||
static void dbg_color(const std::string& s, const std::string& fg, size_t bg_index) {
    static const std::vector<std::string> kBackgrounds = {"\033[40m", "\033[48m"};
    static const std::string kReset = "\033[0m";
    std::cout
        << kBackgrounds[bg_index % kBackgrounds.size()]
        << fg << s << kReset << std::flush;
}

static void dbg_accepted(const std::string& accepted, size_t bg_index) {
    static const std::string kGreen = "\033[32m";
    dbg_color(accepted, kGreen, bg_index);
}

// something the main model generated which was accepted but not matched by the draft model
static void dbg_not_matched(const std::string& accepted, size_t bg_index) {
    dbg_color(accepted, "", bg_index);
}

static void dbg_rejected(const std::string& rejected, size_t bg_index) {
    static const std::string kRed = "\033[31m";
    dbg_color(rejected, kRed, bg_index);
}

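// How the two threads cooperate (summary of the logic implemented below):
// - the draft thread (draft_loop) keeps extending spec_ctx->speculation with
//   greedy tokens from the small model;
// - the main thread (main_loop) verifies the speculated continuation with the
//   large model, truncates the shared speculation at the first mismatch and
//   appends its own verified tokens;
// - the draft thread notices the truncation, adopts the shared sequence and
//   rolls its KV cache back to the divergence point;
// - `done` is set by the main thread to stop the draft thread.
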
// shared data between main and speculation process
struct linear_speculative_context {
    std::vector<llama_token> speculation;
    std::mutex mtx;
    bool done;
};

// greedy sampling
static std::vector<llama_token> greedy_tokens(llama_model* model, llama_context* ctx, int from_idx, int to_idx) {
    auto n_vocab = llama_n_vocab(model);
    std::vector<llama_token_data> candidates;
    candidates.resize(n_vocab);
    std::vector<llama_token> res;

    for (int idx = from_idx; idx < to_idx; idx++) {
        auto * logits = llama_get_logits_ith(ctx, idx);
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            candidates[token_id] = llama_token_data{ token_id, logits[token_id], 0.0f };
        }

        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

        // sample the most likely token
        const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
        res.push_back(new_token_id);
    }
    return res;
}

static int main_loop(
        llama_model *model,
        linear_speculative_context *spec_ctx,
        llama_context *ctx,
        std::vector<llama_token> tokens_list /* copy here */) {
    const int n_len = 1024;

    llama_batch batch = llama_batch_init(1024, 0, 1);

    // evaluate the initial prompt
    for (size_t i = 0; i < tokens_list.size(); i++) {
        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
    }

    // llama_decode will output logits only for the last token of the prompt
    batch.logits[batch.n_tokens - 1] = true;

    if (llama_decode(ctx, batch) != 0) {
        LOG_TEE("%s: llama_decode() failed\n", __func__);
        return 1;
    }

    // how many tokens are currently accepted
    int n_cur = batch.n_tokens;

    std::vector<llama_token> input_seq, next_tokens;
    input_seq.push_back(tokens_list.back());

    int logits_from = n_cur - 1;
    int logits_to = n_cur;
    size_t bg_index = 0;

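    // verification loop: the batch that was just decoded contains input_seq
    // (the last accepted token followed by the current speculative tail), so
    // greedy_tokens() returns the main model's prediction after each of those
    // positions; a drafted token is accepted iff it equals the main model's
    // prediction for the position before it.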
    while (n_cur <= n_len) {
        bg_index++;
        next_tokens = greedy_tokens(model, ctx, logits_from, logits_to);
        if (next_tokens.size() != input_seq.size()) {
            fprintf(stderr, "invalid next tokens\n");
            return 1;
        }

        // this is where next_tokens start
        int next_tokens_pos = n_cur;
        // we always accept at least one new token
        n_cur += 1;
        for (size_t i = 0; i + 1 < input_seq.size(); i++) {
            if (next_tokens[i] == input_seq[i + 1]) {
                n_cur += 1;
            } else {
                // reject. next_tokens[i] is the last 'correct' one.
                next_tokens.erase(next_tokens.begin() + i + 1, next_tokens.end());
                break;
            }
        }
        // invalidate the main model's KV cache from position n_cur - 1 onward:
        // everything past the accepted prefix is stale, and the last accepted
        // token is re-submitted as the first token of the next batch below
        llama_kv_cache_seq_rm(ctx, 0, n_cur - 1, -1);

        bool done = false;
        for (llama_token new_token_id: next_tokens) {
            if (new_token_id == llama_token_eos(model)) {
                done = true;
            }
        }
        if (n_cur >= n_len || done) {
            break;
        }

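        // Under the mutex: count how many of the freshly verified tokens agree
        // with what the draft thread has published at the same positions
        // (n_match); on a mismatch, cut the shared speculation back to the
        // verified prefix and append the main model's own tokens.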
        // CRITICAL SECTION -- reconcile main and speculative
        {
            std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
            auto& spec = spec_ctx->speculation;
            size_t n_match = 0;
            for (size_t i = 0; i < next_tokens.size() && i + next_tokens_pos < spec.size(); i++) {
                if (next_tokens[i] == spec[i + next_tokens_pos]) {
                    n_match++;
                } else {
                    break;
                }
            }

            std::string accepted = "";
            // Write accepted/rejected/not matched
            // this is slow and inefficient but for short strings doesn't matter
            for (size_t i = next_tokens_pos; i < next_tokens_pos + n_match; i++) {
                accepted += llama_token_to_piece(ctx, spec[i]);
            }
            dbg_accepted(accepted, bg_index);
            if (n_match != next_tokens.size()) {
                std::string rejected = "";
                for (size_t i = next_tokens_pos + n_match; i < spec.size(); i++) {
                    rejected += llama_token_to_piece(ctx, spec[i]);
                }
                dbg_rejected(rejected, bg_index);
                // need to modify speculation
                spec.erase(spec.begin() + next_tokens_pos, spec.end());
                for (const auto tok: next_tokens) {
                    spec.push_back(tok);
                }
                std::string not_matched = "";
                for (size_t i = n_match; i < next_tokens.size(); i++) {
                    not_matched += llama_token_to_piece(ctx, next_tokens[i]);
                }
                dbg_not_matched(not_matched, bg_index);
            }

            input_seq.assign(spec.begin() + n_cur - 1, spec.end());
        }

        llama_batch_clear(batch);
        for (size_t i = 0; i < input_seq.size(); i++) {
            llama_batch_add(batch, input_seq[i], n_cur - 1 + i, { 0 }, true);
        }
        if (llama_decode(ctx, batch)) {
            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
            return 1;
        }
        logits_from = 0;
        logits_to = input_seq.size();
    }

    for (size_t i = 0; i < next_tokens.size(); i++) {
        dbg_not_matched(llama_token_to_piece(ctx, next_tokens[i]), bg_index);
    }
    std::cout << std::endl << std::endl;
    {
        std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
        spec_ctx->done = true;
    }

    llama_batch_free(batch);
    return 0;
}

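// draft_loop: runs the draft model on its own thread, continuously extending
// the shared speculation one greedy token at a time and re-synchronizing with
// the main thread's corrections.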
static int draft_loop(
        llama_model *model,
        linear_speculative_context *spec_ctx,
        llama_context *ctx,
        std::vector<llama_token> tokens_list /* copy here */) {

    llama_batch batch = llama_batch_init(512, 0, 1);

    // evaluate the initial prompt
    for (size_t i = 0; i < tokens_list.size(); i++) {
        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
    }

    // llama_decode will output logits only for the last token of the prompt
    batch.logits[batch.n_tokens - 1] = true;

    if (llama_decode(ctx, batch) != 0) {
        LOG_TEE("%s: llama_decode() failed\n", __func__);
        return 1;
    }

    int logit_idx = batch.n_tokens - 1;
    std::vector<llama_token> local_spec = tokens_list;
    size_t match_len;

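    // each iteration: sample one greedy token from the draft model, then compare
    // the local sequence with the shared speculation; if they diverge (the main
    // thread rewrote the tail), adopt the shared sequence and roll the draft KV
    // cache back to the first mismatching position, then re-decode from there.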
    while (true) {
        auto next_tokens = greedy_tokens(model, ctx, logit_idx, logit_idx + 1);
        if (next_tokens.size() != 1) {
            fprintf(stderr, "invalid next tokens\n");
            return 1;
        }

        local_spec.push_back(next_tokens[0]);

        {
            std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
            if (spec_ctx->done) {
                break;
            }
            auto& spec = spec_ctx->speculation;
            bool match = true;
            match_len = local_spec.size() - 1;
            for (size_t i = 0; i < std::min(spec.size(), local_spec.size()); i++) {
                if (spec[i] != local_spec[i]) {
                    match = false;
                    match_len = i;
                    llama_kv_cache_seq_rm(ctx, 0, i, -1);
                    break;
                }
            }
            if (match) {
                // local speculation is still consistent with the shared one: publish it
                spec = local_spec;
            } else {
                // the main thread rewrote the tail: adopt the shared sequence
                local_spec = spec;
            }
        }

        llama_batch_clear(batch);
        // TODO theoretically this can be empty?
        for (size_t i = match_len; i < local_spec.size(); i++) {
            llama_batch_add(batch, local_spec[i], i, { 0 }, true);
        }

        logit_idx = batch.n_tokens - 1;

        // evaluate the current batch with the transformer model
        if (llama_decode(ctx, batch)) {
            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
            return 1;
        }
    }

    llama_batch_free(batch);
    return 0;
}

int main(int argc, char ** argv) {
    gpt_params params;

    llama_backend_init();
    llama_numa_init(params.numa);

    // init context params
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.seed = 1234;
    ctx_params.n_ctx = 2048;
    ctx_params.n_threads = params.n_threads;
    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

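    // expected invocation (inferred from the argv handling below):
    //   async_spec <main model path> <draft model path> [prompt]
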
    // Init main model and context
    if (argc >= 2) {
        params.model = argv[1];
    }
    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = 99;
    llama_model *main_model = llama_load_model_from_file(params.model.c_str(), model_params);
    llama_context *main_ctx = llama_new_context_with_model(main_model, ctx_params);

    // Init draft model (kept on CPU: n_gpu_layers = 0)
    if (argc >= 3) {
        params.model = argv[2];
    }
    model_params.n_gpu_layers = 0;
    llama_model *draft_model = llama_load_model_from_file(params.model.c_str(), model_params);
    llama_context *draft_ctx = llama_new_context_with_model(draft_model, ctx_params);

    // Print & tokenize prompt
    // tokenizer should be the same and prompt tokens should be the same
    if (argc >= 4) {
        params.prompt = argv[3];
    }
    if (params.prompt.empty()) {
        params.prompt = "What's the difference between instruction cache and data cache?";
    }
    dbg_not_matched(params.prompt, 0);
    std::vector<llama_token> tokens_list = llama_tokenize(main_ctx, params.prompt, true);

    // Init shared speculative context
    linear_speculative_context spec_ctx;
    spec_ctx.speculation = tokens_list;
    spec_ctx.done = false;

    const auto t_main_start = ggml_time_us();
    std::thread t_cpu(draft_loop, draft_model, &spec_ctx, draft_ctx, tokens_list);
    std::thread t_gpu(main_loop, main_model, &spec_ctx, main_ctx, tokens_list);
    t_gpu.join();
    t_cpu.join();
    const auto t_main_end = ggml_time_us();

    printf("Total time: %.3lf\n", (t_main_end - t_main_start) / 1000000.0);

    // free the contexts before their models
    llama_free(main_ctx);
    llama_free_model(main_model);
    llama_free(draft_ctx);
    llama_free_model(draft_model);
    llama_backend_free();

    return 0;
}
@@ -0,0 +1 @@
<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful, respectful and honest assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nPlease give a detailed description of concurrency and parallelism in Python. Provide some examples.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n