perplexity : faster Winogrande via batching #5024

Merged
merged 3 commits into from Jan 19, 2024
Changes from 1 commit
perplexity : faster Winogrande via batching
ggml-ci
ggerganov committed Jan 18, 2024
commit 0f3ed789b32a239c034e34ee284ae3df280a1096
275 changes: 164 additions & 111 deletions examples/perplexity/perplexity.cpp
@@ -445,6 +445,33 @@ static std::vector<float> evaluate_tokens(llama_context * ctx, std::vector<int>
return result;
}

static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));

llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};

Contributor
The llama_kv_cache_seq_rm call is no longer needed here?

Owner Author
It's not needed because we clear the entire KV cache before each batch:

llama_kv_cache_clear(ctx);

In the old implementation, it was reusing tokens from a previous batch, so the llama_kv_cache_seq_rm was used to evict the unused ones (i.e. the second sentence).
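
For context, a minimal sketch (not part of this diff) of the two KV-cache strategies being discussed, assuming the llama.cpp API as it existed at the time of this PR; the helper names are hypothetical:

// Hypothetical sketch, not from the PR: contrast of the two KV-cache strategies.
#include "llama.h"

// Old flow: the cache still holds tokens from the previously evaluated ending,
// so everything after the shared prefix is evicted before decoding the next one.
static void prepare_cache_old(llama_context * ctx, llama_seq_id seq, llama_pos prefix_len) {
    llama_kv_cache_seq_rm(ctx, seq, prefix_len, -1); // drop positions [prefix_len, end)
}

// New flow: every batch of tasks starts from an empty cache, so nothing from a
// previous batch can leak into the logits and no per-sequence eviction is needed.
static void prepare_cache_new(llama_context * ctx) {
    llama_kv_cache_clear(ctx);
}

Starting each batch from an empty cache keeps the batched path simple: there is no stale per-sequence state to track between batches.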

const int ret = llama_decode(ctx, batch_view);
if (ret != 0) {
LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false;
}

memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
}

return true;
}

static void hellaswag_compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
constexpr int k_token_chunk = 4;
@@ -576,7 +603,6 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {

// determine the common prefix of the endings
hs_cur.common_prefix = 0;
hs_cur.required_tokens = 0;
for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
@@ -609,45 +635,18 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch;

const int max_tasks_per_batch = params.n_parallel;
const int max_tasks_per_batch = 32;
const int max_seq = 4*max_tasks_per_batch;

llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

std::vector<float> tok_logits(n_vocab);
std::vector<float> batch_logits(n_ctx*n_vocab);
std::vector<float> batch_logits(n_vocab*n_ctx);

std::vector<std::pair<size_t, llama_token>> eval_pairs;
std::vector<float> eval_results;
std::vector<std::thread> workers(std::thread::hardware_concurrency());

auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));

llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};

const int ret = llama_decode(ctx, batch_view);
if (ret != 0) {
LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false;
}

memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
}

return true;
};

for (size_t i0 = 0; i0 < hs_task_count; i0++) {
int n_cur = 0;

@@ -696,7 +695,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
llama_kv_cache_clear(ctx);

// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, n_batch)) {
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
return;
}
@@ -772,6 +771,13 @@ struct winogrande_entry {
std::string second;
std::array<std::string, 2> choices;
int answer;

size_t i_batch;
size_t common_prefix;
size_t required_tokens;
size_t n_base1; // number of tokens for context + choice 1
size_t n_base2; // number of tokens for context + choice 2
std::vector<llama_token> seq_tokens[2];
};

static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string& prompt) {
@@ -854,6 +860,29 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
return;
}

// This is needed as usual for LLaMA models
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));

for (auto & task : data) {
task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, add_bos);
task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, add_bos);

task.common_prefix = 0;
for (size_t k = 0; k < task.seq_tokens[0].size(); k++) {
if (task.seq_tokens[0][k] != task.seq_tokens[1][k]) {
break;
}
task.common_prefix++;
}

task.required_tokens = task.common_prefix +
task.seq_tokens[0].size() - task.common_prefix +
task.seq_tokens[1].size() - task.common_prefix;

task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos).size();
task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos).size();
}

ggerganov marked this conversation as resolved.
fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, data.size());

if (params.winogrande_tasks > 0 && params.winogrande_tasks < data.size()) {
@@ -875,115 +904,139 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
data = std::move(selected);
}

// This is needed as usual for LLaMA models
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));

fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);

const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx);
const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch;

const int max_tasks_per_batch = 128;
const int max_seq = 2*max_tasks_per_batch;

llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

std::vector<float> tok_logits(n_vocab);
std::vector<float> batch_logits(n_vocab*n_ctx);

int n_correct = 0;
int n_done = 0;

for (size_t task_idx = 0; task_idx < data.size(); task_idx++) {
const auto& task = data[task_idx];
for (size_t i0 = 0; i0 < data.size(); i0++) {
int n_cur = 0;

auto base_context = ::llama_tokenize(ctx, task.first, add_bos);
auto base_ctx_1st = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos);
auto base_ctx_2nd = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos);
size_t i1 = i0;
size_t i_batch = 0;

auto sentence_1st = task.first + task.choices[0] + task.second;
auto sentence_2nd = task.first + task.choices[1] + task.second;
auto query_1st = ::llama_tokenize(ctx, sentence_1st, add_bos);
auto query_2nd = ::llama_tokenize(ctx, sentence_2nd, add_bos);
llama_batch_clear(batch);

if (query_1st.size() > (size_t)n_ctx || query_2nd.size() > (size_t)n_ctx) {
fprintf(stderr, "%s : number of tokens in queries %zu, %zu > n_ctxl\n", __func__, query_1st.size(), query_2nd.size());
return;
}
while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
const int s0 = 2*(i1 - i0);
if (s0 + 2 > max_seq) {
break;
}

auto query_1st_size = query_1st.size();
auto query_2nd_size = query_2nd.size();
for (size_t i = 0; i < data[i1].common_prefix; ++i) {
llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1}, false);
}
batch.logits[batch.n_tokens - 1] = true;

for (int s = 0; s < 2; ++s) {
for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
}
}

// Speedup small evaluations by evaluating atleast 32 tokens
// For Winogrande this seems to slow it down rather than speed it up.
//if (query_1st.size() < 32) query_1st.resize(32);
//if (query_2nd.size() < 32) query_2nd.resize(32);
data[i1].i_batch = i_batch;
i_batch += data[i1].required_tokens;

llama_kv_cache_clear(ctx);
auto logits_1st = evaluate_tokens(ctx, query_1st, 0, params.n_batch, n_vocab);
n_cur += data[i1].required_tokens;
if (++i1 == data.size()) {
break;
}
}

if (i0 == i1) {
fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
return;
}

llama_kv_cache_clear(ctx);
auto logits_2nd = evaluate_tokens(ctx, query_2nd, 0, params.n_batch, n_vocab);

if (logits_1st.empty() || logits_2nd.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
return;
}

bool skip_choice = query_1st_size - base_ctx_1st.size() > k_min_trailing_ctx &&
query_2nd_size - base_ctx_2nd.size() > k_min_trailing_ctx;

float score_1st = 0;
bool is_nan_1st = false;
const auto& base_1 = skip_choice ? base_ctx_1st : base_context;
const int last_1st = query_1st_size - base_1.size() > 1 ? 1 : 0;
for (size_t j = base_1.size()-1; j < query_1st_size-1-last_1st; ++j) {
std::memcpy(tok_logits.data(), logits_1st.data() + j*n_vocab, n_vocab*sizeof(float));
const float prob = softmax(tok_logits)[query_1st[j+1]];
if (std::isnan(prob) || !prob) {
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
prob, j, sentence_1st.c_str(), base_context.size());
is_nan_1st = true;
break;
for (size_t i = i0; i < i1; ++i) {
auto & task = data[i];

const bool skip_choice =
task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;

float score_1st = 0;
bool is_nan_1st = false;
const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
size_t li = n_base1 - 1;
for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
const float prob = softmax(tok_logits)[task.seq_tokens[0][j+1]];
if (std::isnan(prob) || !prob) {
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
prob, j, (task.first + task.choices[0] + task.second).c_str(), n_base1);
is_nan_1st = true;
break;
}
score_1st += std::log(prob);
}
score_1st += std::log(prob);
}
score_1st /= (query_1st_size - base_1.size() - last_1st);

float score_2nd = 0;
bool is_nan_2nd = false;
const auto& base_2 = skip_choice ? base_ctx_2nd : base_context;
const int last_2nd = query_2nd_size - base_2.size() > 1 ? 1 : 0;
for (size_t j = base_2.size()-1; j < query_2nd_size-1-last_2nd; ++j) {
std::memcpy(tok_logits.data(), logits_2nd.data() + j*n_vocab, n_vocab*sizeof(float));
const float prob = softmax(tok_logits)[query_2nd[j+1]];
if (std::isnan(prob) || !prob) {
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
prob, j, sentence_2nd.c_str(), base_context.size());
is_nan_2nd = true;
break;
score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);

float score_2nd = 0;
bool is_nan_2nd = false;
const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1;
for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
const float prob = softmax(tok_logits)[task.seq_tokens[1][j+1]];
if (std::isnan(prob) || !prob) {
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
prob, j, (task.first + task.choices[1] + task.second).c_str(), n_base2);
is_nan_2nd = true;
break;
}
score_2nd += std::log(prob);
}
score_2nd += std::log(prob);
}
score_2nd /= (query_2nd_size - base_2.size() - last_2nd);
score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);

if (is_nan_1st || is_nan_2nd) {
continue;
}
if (is_nan_1st || is_nan_2nd) {
continue;
}

if (std::isnan(score_1st) || std::isnan(score_2nd)) {
printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
printf("Q1: <%s> - %zu tokens\n", sentence_1st.c_str(), query_1st_size);
printf("Q2: <%s> - %zu tokens\n", sentence_2nd.c_str(), query_2nd_size);
printf("B : <%s> - %zu tokens\n", task.first.c_str(), base_context.size());
printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", base_1.size(), base_2.size(), skip_choice);
continue;
}
if (std::isnan(score_1st) || std::isnan(score_2nd)) {
printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
printf("Q1: <%s> - %zu tokens\n", (task.first + task.choices[0] + task.second).c_str(), task.seq_tokens[0].size());
printf("Q2: <%s> - %zu tokens\n", (task.first + task.choices[1] + task.second).c_str(), task.seq_tokens[1].size());
printf("B : <%s> - %zu tokens\n", task.first.c_str(), task.common_prefix);
printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", n_base1, n_base2, skip_choice);
continue;
}

int result = score_1st > score_2nd ? 1 : 2;
int result = score_1st > score_2nd ? 1 : 2;

if (result == task.answer) {
++n_correct;
}
++n_done;

if (result == task.answer) {
++n_correct;
// Print the accumulated accuracy mean x 100
printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
fflush(stdout);
}
++n_done;

// Print the accumulated accuracy mean x 100
printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n",task_idx+1, 100.0 * n_correct/n_done,score_1st,score_2nd,result,task.answer);
fflush(stdout);
i0 = i1 - 1;
}

printf("\n");