
perplexity : support using multiple sequences to allow larger batch sizes #5946


Merged · 3 commits · Mar 9, 2024
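The change packs several n_ctx-sized evaluation chunks into a single llama_batch, each tagged as its own sequence, so the batch submitted to llama_decode can be larger than the context window used for scoring. A minimal sketch of the size arithmetic involved, using illustrative values rather than anything taken from the PR:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    // Illustrative values only; the real numbers come from the command line.
    const int n_ctx   = 512;                          // context window per chunk
    const int n_batch = 2048;                         // tokens per llama_decode call
    const int n_seq   = std::max(1, n_batch / n_ctx); // chunks evaluated per pass -> 4
    const int n_kv    = n_seq * n_ctx;                // KV cache must hold all sequences -> 2048

    printf("n_seq = %d, n_kv = %d\n", n_seq, n_kv);
    return 0;
}
```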
examples/perplexity/perplexity.cpp (139 changes: 91 additions & 48 deletions)
@@ -442,7 +442,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
return {tokens, std::exp(nll / count), logit_history, prob_history};
}

static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
static results_perplexity perplexity(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) {
if (params.ppl_stride > 0) {
return perplexity_v2(ctx, params);
}
@@ -453,7 +453,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// BOS tokens will be added for each chunk before eval

const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx);

std::ofstream logits_stream;
if (!params.logits_file.empty()) {
@@ -499,13 +498,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
double nll2 = 0.0;

const int num_batches = (n_ctx + n_batch - 1) / n_batch;
const int n_seq = std::max(1, n_batch / n_ctx);

GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
GGML_ASSERT(params.n_ctx == n_seq * n_ctx);

llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);

std::vector<float> logits;
if (num_batches > 1) {
logits.reserve((size_t)n_ctx * n_vocab);
}

fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
fprintf(stderr, "%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);

std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);

@@ -518,10 +523,26 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
log_probs.resize(n_ctx * nv);
}

for (int i = 0; i < n_chunk; ++i) {
// We get the logits for all the tokens in the context window (params.n_ctx)
// from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
// calculate the perplexity over the last half of the window (so the model always has
// some context to predict the token).
//
// We rely on the fact that attention in the forward pass only looks at previous
// tokens here, so the logits returned for each token are an accurate representation
// of what the model would have predicted at that point.
//
// Example, we have a context window of 512, we will compute perplexity for each of the
// last 256 tokens. Then, we split the input up into context window size chunks to
// process the entire prompt.
const int first = n_ctx/2;

for (int i = 0; i < n_chunk; i += n_seq) {
const int start = i * n_ctx;
const int end = start + n_ctx;

const int n_seq_batch = std::min(n_seq, n_chunk - i);

const auto t_start = std::chrono::high_resolution_clock::now();

// clear the KV cache
@@ -531,22 +552,37 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);

// save original token and restore it after eval
const auto token_org = tokens[batch_start];
batch.n_tokens = 0;
for (int seq = 0; seq < n_seq_batch; seq++) {
int seq_start = batch_start + seq*n_ctx;

// add BOS token for the first batch of each chunk
if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
// save original token and restore it after eval
const auto token_org = tokens[seq_start];

// add BOS token for the first batch of each chunk
if (add_bos && j == 0) {
tokens[seq_start] = llama_token_bos(llama_get_model(ctx));
}

for (int k = 0; k < batch_size; ++k) {
const int idx = seq*n_ctx + k;
batch.token[idx] = tokens[seq_start + k];
batch.pos[idx] = j*n_batch + k;
batch.n_seq_id[idx] = 1;
batch.seq_id[idx][0] = seq;
batch.logits[idx] = batch.pos[idx] >= first ? 1 : 0;
}
batch.n_tokens += batch_size;

// restore the original token in case it was set to BOS
tokens[seq_start] = token_org;
}

if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}

// restore the original token in case it was set to BOS
tokens[batch_start] = token_org;

if (num_batches > 1) {
const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
@@ -558,45 +594,39 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
if (i == 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk);
int total_seconds = (int)(t_total*n_chunk/n_seq);
if (total_seconds >= 60*60) {
fprintf(stderr, "%d hours ", total_seconds / (60*60));
total_seconds = total_seconds % (60*60);
}
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
}

// We get the logits for all the tokens in the context window (params.n_ctx)
// from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
// calculate the perplexity over the last half of the window (so the model always has
// some context to predict the token).
//
// We rely on the fact that attention in the forward pass only looks at previous
// tokens here, so the logits returned for each token are an accurate representation
// of what the model would have predicted at that point.
//
// Example, we have a context window of 512, we will compute perplexity for each of the
// last 256 tokens. Then, we split the input up into context window size chunks to
// process the entire prompt.
const int first = n_ctx/2;
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
if (!params.logits_file.empty()) {
process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, log_probs, nll, nll2);
} else {
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
}
count += n_ctx - first - 1;

// perplexity is e^(average negative log-likelihood)
if (params.ppl_output_type == 0) {
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
} else {
double av = nll/count;
double av2 = nll2/count - av*av;
if (av2 > 0) av2 = sqrt(av2/(count-1));
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
for (int seq = 0; seq < n_seq_batch; seq++) {
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx);
llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
if (!params.logits_file.empty()) {
process_logits(logits_stream, n_vocab, all_logits + first*n_vocab,
tokens_data, n_ctx - 1 - first,
workers, log_probs, nll, nll2);
} else {
process_logits(n_vocab, all_logits + first*n_vocab,
tokens_data, n_ctx - 1 - first,
workers, nll, nll2,
logit_history.data() + start + seq*n_ctx + first,
prob_history.data() + start + seq*n_ctx + first);
}
count += n_ctx - first - 1;

// perplexity is e^(average negative log-likelihood)
if (params.ppl_output_type == 0) {
printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
} else {
double av = nll/count;
double av2 = nll2/count - av*av;
if (av2 > 0) av2 = sqrt(av2/(count-1));
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
}
}
fflush(stdout);

@@ -615,6 +645,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
printf("Unexpected negative standard deviation of log(prob)\n");
}

llama_batch_free(batch);

return {tokens, ppl, logit_history, prob_history};
}

@@ -1782,13 +1814,24 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
int main(int argc, char ** argv) {
gpt_params params;

params.n_batch = 512;
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}

params.logits_all = true;
params.n_batch = std::min(params.n_batch, params.n_ctx);

const int32_t n_ctx = params.n_ctx;

const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
if (ppl) {
int n_seq = std::max(1, params.n_batch / n_ctx);
int32_t n_kv = n_seq * n_ctx;
params.n_parallel = n_seq;
params.n_ctx = n_kv;
params.n_batch = std::min(params.n_batch, n_kv);
} else {
params.n_batch = std::min(params.n_batch, params.n_ctx);
}

if (params.ppl_stride > 0) {
fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
@@ -1847,7 +1890,7 @@ int main(int argc, char ** argv) {
} else if (params.kl_divergence) {
kl_divergence(ctx, params);
} else {
results = perplexity(ctx, params);
results = perplexity(ctx, params, n_ctx);
}

llama_print_timings(ctx);
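Taken together, the perplexity loop now builds one llama_batch covering n_seq_batch chunks, gives each chunk its own sequence id, and requests logits only for positions in the second half of each window. The sketch below condenses that packing step under the simplifying assumption that a whole chunk fits in a single decode call (the real loop splits each chunk across n_batch-sized sub-batches and offsets the positions accordingly); the helper name decode_chunks is invented for illustration, and BOS handling and error reporting are omitted:

```cpp
#include <vector>
#include "llama.h"

// Sketch: pack n_seq_batch chunks of n_ctx tokens into one batch, mirroring the
// loop in perplexity(). `batch` is assumed to come from
// llama_batch_init(n_seq_batch*n_ctx, 0, 1), as in the diff above.
static bool decode_chunks(llama_context * ctx, const std::vector<llama_token> & tokens,
                          int start, int n_ctx, int n_seq_batch, llama_batch & batch) {
    const int first = n_ctx/2; // only the last half of each window is scored

    batch.n_tokens = 0;
    for (int seq = 0; seq < n_seq_batch; seq++) {
        const int seq_start = start + seq*n_ctx;
        for (int k = 0; k < n_ctx; ++k) {
            const int idx = seq*n_ctx + k;
            batch.token   [idx]    = tokens[seq_start + k];
            batch.pos     [idx]    = k;    // position within this sequence
            batch.n_seq_id[idx]    = 1;
            batch.seq_id  [idx][0] = seq;  // each chunk is an independent sequence
            batch.logits  [idx]    = k >= first ? 1 : 0;
        }
        batch.n_tokens += n_ctx;
    }

    // all sequences are evaluated in a single call
    return llama_decode(ctx, batch) == 0;
}
```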
llama.cpp (22 changes: 17 additions & 5 deletions)
@@ -8925,17 +8925,29 @@ static int llama_decode_internal(

if (batch.logits) {
logits_out.resize(n_vocab * n_tokens);
int32_t i_first = -1;
for (uint32_t i = 0; i < n_tokens; i++) {
if (batch.logits[i] == 0) {
continue;
if (batch.logits[i] && i_first == -1) {
i_first = (int32_t) i;
}
if (batch.logits[i] == 0 || i == n_tokens - 1) {
if (i_first != -1) {
int i_last = batch.logits[i] == 0 ? i : i + 1;
// extract logits for the range [i_first, i_last)
// group the requests to minimize the number of calls to the backend
ggml_backend_tensor_get_async(backend_res, res,
logits_out.data() + (n_vocab*i_first),
(n_vocab*i_first)*sizeof(float),
(i_last - i_first)*n_vocab*sizeof(float));
i_first = -1;
}
}
ggml_backend_tensor_get_async(backend_res, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float));
#ifndef NDEBUG
logits_valid[i] = true;
logits_valid[i] = batch.logits[i] != 0;
#endif
}
} else if (lctx.logits_all) {
logits_out.resize(n_vocab * n_tokens);
logits_out.resize(n_vocab*n_tokens);
ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
#ifndef NDEBUG
std::fill(logits_valid.begin(), logits_valid.end(), true);
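On the llama.cpp side, the per-token logits flags are kept, but rows are now copied in contiguous runs with one backend call per run instead of one call per token. The run detection can be shown in isolation; the standalone sketch below substitutes memcpy for ggml_backend_tensor_get_async and uses a made-up flag pattern and vocab size purely for illustration:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
    // Toy illustration of the [i_first, i_last) grouping used when extracting logits.
    // Vocab size and flag pattern are made up for the example.
    const int    n_vocab  = 4;
    const int    n_tokens = 6;
    const int8_t logits_flags[n_tokens] = {0, 1, 1, 0, 1, 1}; // which tokens requested logits

    std::vector<float> src(n_vocab*n_tokens);
    for (size_t i = 0; i < src.size(); ++i) {
        src[i] = (float) i;
    }
    std::vector<float> out(n_vocab*n_tokens, 0.0f);

    int i_first = -1;
    for (int i = 0; i < n_tokens; i++) {
        if (logits_flags[i] && i_first == -1) {
            i_first = i; // start of a run of requested rows
        }
        if (logits_flags[i] == 0 || i == n_tokens - 1) {
            if (i_first != -1) {
                const int i_last = logits_flags[i] == 0 ? i : i + 1;
                // one copy per contiguous run, instead of one per token
                printf("copying rows [%d, %d)\n", i_first, i_last);
                memcpy(out.data() + n_vocab*i_first,
                       src.data() + n_vocab*i_first,
                       (size_t)(i_last - i_first)*n_vocab*sizeof(float));
                i_first = -1;
            }
        }
    }
    return 0;
}
```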