llama : auto-batch preparation #13845

Merged · 2 commits · May 31, 2025
2 changes: 1 addition & 1 deletion examples/parallel/parallel.cpp
@@ -392,7 +392,7 @@ int main(int argc, char ** argv) {
                     return 1;
                 }
 
-                LOG_ERR("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
+                LOG_WRN("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
 
                 n_cache_miss += 1;
 
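The demotion from LOG_ERR to LOG_WRN reflects the new behavior introduced by this PR: a decode returning 1 is recoverable by retrying with a smaller batch, so it is no longer an error. A condensed sketch of the caller-side retry pattern around this call (paraphrased, not the verbatim parallel.cpp code; decode_chunk() is a hypothetical helper standing in for the batch-view plumbing):

    int decode_chunk(llama_context * ctx, int32_t n_batch);  // hypothetical helper wrapping llama_decode

    // Paraphrased sketch: halve n_batch until the chunk fits in the KV cache
    // or no further halving is possible.
    static int decode_with_backoff(llama_context * ctx, int32_t n_batch, int & n_cache_miss) {
        while (true) {
            const int ret = decode_chunk(ctx, n_batch);
            if (ret == 0) {
                return 0;                  // chunk decoded successfully
            }
            if (ret < 0 || n_batch == 1) {
                // fatal: even a single-token batch does not fit in the KV cache
                LOG_ERR("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
                return 1;
            }

            // recoverable: this is the path that now warns instead of erroring
            LOG_WRN("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);

            n_cache_miss += 1;
            n_batch /= 2;
        }
    }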
105 changes: 58 additions & 47 deletions src/llama-context.cpp
@@ -424,28 +424,33 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-void llama_context::kv_self_update() {
+bool llama_context::kv_self_update() {
     if (!memory) {
-        return;
+        return false;
     }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    if (kv_self->update(*this)) {
-        // if the KV cache did any computation, we have to reserve a new worst-case graph
-        const auto kv_state = kv_self->init_full();
-        if (!kv_state) {
-            throw std::runtime_error("failed to initialize KV cache");
-        }
+    if (!kv_self->update(*this)) {
+        // no updates have been performed
+        return false;
+    }
 
-        const uint32_t n_seqs   = cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+    // if the KV cache did any computation, we have to reserve a new worst-case graph
+    const auto kv_state = kv_self->init_full();
+    if (!kv_state) {
+        throw std::runtime_error("failed to initialize KV cache");
+    }
 
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
-        if (!gf) {
-            LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
-        }
-    }
+    const uint32_t n_seqs   = cparams.n_seq_max;
+    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
+    }
+
+    return true;
 }
 
 enum llama_pooling_type llama_context::pooling_type() const {
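Returning bool here is what enables the auto-retry in decode() below: the caller can now tell whether a scheduled shift/defrag actually performed work. A minimal sketch of consuming the new return value (the helper is hypothetical; in the real change this logic lives inline in decode()):

    // Hypothetical helper illustrating the new kv_self_update() contract.
    // A negative threshold forces the defrag to be scheduled (as used in this PR);
    // kv_self_update() returns true only if a shift/defrag actually ran, and in
    // that case a new worst-case graph has already been reserved.
    static bool force_defrag_and_update(llama_context & lctx, llama_kv_cache & kv) {
        kv.defrag_sched(-1.0f);
        return lctx.kv_self_update();
    }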
@@ -933,24 +938,44 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();
 
-    auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-    if (!kv_state) {
-        return -2;
-    }
+    llama_memory_state_ptr kv_state;
 
-    switch (kv_state->get_status()) {
-        case LLAMA_MEMORY_STATUS_SUCCESS:
-            {
-            } break;
-        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
-            {
-                // not a fatal error, we can re-try with a different batch
-                return 1;
-            }
-        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
-            {
-                return -2;
-            }
-    }
+    bool did_defrag = false;
+
+    while (true) {
+        kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!kv_state) {
+            return -2;
+        }
+
+        switch (kv_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                } break;
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+                {
+                    if (!did_defrag) {
+                        did_defrag = true;
+
+                        kv_self->defrag_sched(-1.0f);
+                        if (kv_self_update()) {
+                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+
+                            continue;
+                        }
+                    }
+
+                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+                    return 1;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    return -2;
+                }
+        }
+
+        break;
+    }
 
     // reserve output buffer
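The did_defrag flag bounds the retry: continue is reachable only once per decode() call, so the while (true) loop runs at most two iterations. A standalone model of the same control flow, where try_prepare() and force_defrag() are hypothetical stand-ins for init_batch() succeeding and for defrag_sched(-1.0f) plus kv_self_update():

    bool try_prepare();   // stand-in for kv_self->init_batch(...) finding a slot
    bool force_defrag();  // stand-in for defrag_sched(-1.0f) + kv_self_update()

    // Returns 0 on success, 1 when no KV cache slot exists even after one
    // forced defrag. Mirrors the one-shot retry structure above.
    int prepare_with_one_retry() {
        bool did_defrag = false;

        while (true) {
            if (try_prepare()) {
                return 0;              // slot found
            }

            if (!did_defrag) {
                did_defrag = true;     // guarantees at most one extra iteration
                if (force_defrag()) {
                    continue;          // the cache changed: second and final attempt
                }
            }

            return 1;                  // still no space: the caller must shrink the batch
        }
    }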
@@ -2646,22 +2671,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch   batch) {
-    int ret = ctx->decode(batch);
-
-    // defrag and try again
-    // TODO: distinguish return code when we are sure that even after defrag there is no space available
-    if (ret == 1) {
-        llama_kv_self_defrag(ctx);
-        ret = ctx->decode(batch);
-
-        if (ret == 1) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
-            return ret;
-        }
-    }
-
-    if (ret != 0) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
 
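With the retry folded into llama_context::decode(), llama_decode() becomes a thin wrapper: it no longer calls the defrag itself, and a return of 1 already accounts for the internal defrag attempt. A hedged sketch of caller-side handling under the new contract (shrink_batch() is a hypothetical stand-in for app-specific logic such as the n_batch halving in server.cpp):

    #include <cstdio>

    void shrink_batch(llama_batch & batch);  // hypothetical: halve/split for a retry

    static bool decode_or_fail(llama_context * ctx, llama_batch & batch) {
        const int ret = llama_decode(ctx, batch);
        if (ret == 0) {
            return true;        // success: outputs for the requested tokens are ready
        }
        if (ret == 1) {
            // no KV cache slot found; the internal defrag retry already ran,
            // so the only remaining option is a smaller batch
            shrink_batch(batch);
            return false;
        }
        // ret < 0: fatal (e.g. invalid batch), do not retry
        fprintf(stderr, "llama_decode failed, ret = %d\n", ret);
        return false;
    }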
3 changes: 2 additions & 1 deletion src/llama-context.h
@@ -50,8 +50,9 @@ struct llama_context {
     llama_kv_cache * get_kv_self();
     const llama_kv_cache * get_kv_self() const;
 
+    // return true if the KV cache was updated
     // TODO: remove
-    void kv_self_update();
+    bool kv_self_update();
 
     enum llama_pooling_type pooling_type() const;
 
5 changes: 3 additions & 2 deletions src/llama-kv-cache.cpp
@@ -1809,9 +1809,10 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
     GGML_UNUSED(embd_pooled);
 
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+    // TODO: if we fail with split_simple, we should attempt different splitting strategies
+    //       but to do that properly, we first have to refactor the batches to be more flexible
 
-    // TODO: if we fail with split_simple, we should attempt split_equal
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
     std::vector<llama_ubatch> ubatches;
 
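For context on the updated TODO: llama_sbatch already exposes a split_equal() mode alongside split_simple(), but switching strategies after a failed slot allocation would mean rebuilding the whole ubatch list, which is why the comment defers this to a batch refactor. A hedged sketch of what such a fallback could look like (only the two split_* calls and the llama_sbatch constructor are existing API; the parameterized retry structure is an assumption, not part of this PR):

    // Hypothetical fallback sketch for the TODO above: build the ubatch list
    // with split_simple(), and if slot allocation later fails, rebuild it with
    // split_equal() and try again.
    static std::vector<llama_ubatch> make_ubatches(const llama_batch & batch,
                                                   const llama_hparams & hparams,
                                                   uint32_t n_ubatch,
                                                   bool logits_all,
                                                   bool simple_split) {
        auto sbatch = llama_sbatch(batch, hparams.n_embd, simple_split, logits_all);

        std::vector<llama_ubatch> ubatches;
        while (sbatch.n_tokens > 0) {
            ubatches.push_back(simple_split
                    ? sbatch.split_simple(n_ubatch)
                    : sbatch.split_equal(n_ubatch));
        }
        return ubatches;
    }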
2 changes: 1 addition & 1 deletion tools/server/server.cpp
@@ -3431,7 +3431,7 @@ struct server_context {
                 // retry with half the batch size to try to find a free slot in the KV cache
                 n_batch /= 2;
 
-                SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
+                SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
 
                 continue; // continue loop of n_batch
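The dropped "enable defragmentation" advice is stale after this PR, since llama_decode() now schedules a defrag itself before reporting failure. The halving retry also terminates quickly: n_batch can be halved at most floor(log2(n_batch)) times before reaching 1, e.g. 2048 -> 1024 -> ... -> 1 in 11 steps. A paraphrased sketch of the surrounding prompt-processing loop (not the verbatim server.cpp code; decode_view() is a hypothetical helper for decoding a sub-range of the batch):

    #include <algorithm>

    int decode_view(llama_context * ctx, const llama_batch & batch, int32_t i, int32_t n_tokens);  // hypothetical

    static void process_prompt(llama_context * ctx, const llama_batch & batch, int32_t n_batch) {
        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

            const int ret = decode_view(ctx, batch, i, n_tokens);
            if (ret != 0 && n_batch > 1) {
                n_batch /= 2;   // retry with half the batch size
                i -= n_batch;   // compensate the loop increment so the same chunk is retried
                continue;       // continue loop of n_batch
            }
        }
    }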