Skip to content

Commit

Permalink
Slightly faster imatrix (ggerganov#5050)
Browse files Browse the repository at this point in the history
* imatrix: speedup by avoiding unnecessary allocations and copies

* imatrix: add --no-ppl option to skip PPL calculations altogether

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
  • Loading branch information
ikawrakow and Kawrakow authored Jan 21, 2024
1 parent 942c010 commit 726c0fa
Showing 1 changed file with 49 additions and 25 deletions.
74 changes: 49 additions & 25 deletions examples/imatrix/imatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ static void process_logits(
}
}

static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl) {

const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx);
Expand All @@ -269,10 +269,12 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
}

std::vector<float> logit_history;
logit_history.resize(tokens.size());

std::vector<float> prob_history;
prob_history.resize(tokens.size());

if (compute_ppl) {
logit_history.resize(tokens.size());
prob_history.resize(tokens.size());
}

const int n_chunk_max = tokens.size() / n_ctx;

Expand All @@ -288,12 +290,17 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {

std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);

const int num_batches = (n_ctx + n_batch - 1) / n_batch;

std::vector<float> logits;
if (compute_ppl && num_batches > 1) {
logits.reserve((size_t)n_ctx * n_vocab);
}

for (int i = 0; i < n_chunk; ++i) {
const int start = i * n_ctx;
const int end = start + n_ctx;

const int num_batches = (n_ctx + n_batch - 1) / n_batch;

std::vector<float> logits;

const auto t_start = std::chrono::high_resolution_clock::now();
Expand Down Expand Up @@ -321,8 +328,10 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
// restore the original token in case it was set to BOS
tokens[batch_start] = token_org;

const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
if (compute_ppl && num_batches > 1) {
const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
}
}

const auto t_end = std::chrono::high_resolution_clock::now();
Expand All @@ -338,25 +347,32 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
}

const int first = n_ctx/2;
process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
count += n_ctx - first - 1;
if (compute_ppl) {
const int first = n_ctx/2;
const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
count += n_ctx - first - 1;

printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
fflush(stdout);

printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
fflush(stdout);
logits.clear();
}
}
printf("\n");

nll2 /= count;
nll /= count;
const double ppl = exp(nll);
nll2 -= nll * nll;
if (nll2 > 0) {
nll2 = sqrt(nll2/(count-1));
printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
} else {
printf("Unexpected negative standard deviation of log(prob)\n");
if (compute_ppl) {
nll2 /= count;
nll /= count;
const double ppl = exp(nll);
nll2 -= nll * nll;
if (nll2 > 0) {
nll2 = sqrt(nll2/(count-1));
printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
} else {
printf("Unexpected negative standard deviation of log(prob)\n");
}
}

return true;
Expand All @@ -365,6 +381,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
int main(int argc, char ** argv) {

StatParams sparams;
bool compute_ppl = true;
std::vector<char*> args;
args.push_back(argv[0]);
int iarg = 1;
Expand All @@ -381,12 +398,19 @@ int main(int argc, char ** argv) {
}
else if (arg == "--verbosity") {
sparams.verbosity = std::stoi(argv[++iarg]);
} else if (arg == "--no-ppl") {
compute_ppl = false;
} else {
args.push_back(argv[iarg]);
}
}
if (iarg < argc) {
args.push_back(argv[iarg]);
std::string arg{argv[iarg]};
if (arg == "--no-ppl") {
compute_ppl = false;
} else {
args.push_back(argv[iarg]);
}
}

gpt_params params;
Expand Down Expand Up @@ -448,7 +472,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s\n", get_system_info(params).c_str());
}

bool OK = compute_imatrix(ctx, params);
bool OK = compute_imatrix(ctx, params, compute_ppl);
if (!OK) {
return 1;
}
Expand Down

0 comments on commit 726c0fa

Please sign in to comment.