server : fix context shift #5195

Merged: 4 commits (Jan 30, 2024)
1 change: 1 addition & 0 deletions examples/server/chat.sh
@@ -48,6 +48,7 @@ chat_completion() {
top_p: 0.9,
n_keep: $n_keep,
n_predict: 256,
cache_prompt: true,
stop: ["\n### Human:"],
stream: true
}')"
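
The added `cache_prompt: true` asks the server to keep the evaluated prompt in the slot's KV cache and reuse the common prefix on the next turn, which is what makes this chat script exercise the context-shift path at all. A rough sketch of how such a per-request flag reaches the slot on the server side (illustrative only; it mimics the `json_value` helper pattern used in server.cpp, and the parsing site itself is outside this diff):

    #include <nlohmann/json.hpp>

    // hypothetical illustration, not part of the PR: read an optional boolean
    // field from the request body, defaulting to false when absent or null
    static bool read_cache_prompt_flag(const nlohmann::json & data) {
        if (data.contains("cache_prompt") && !data.at("cache_prompt").is_null()) {
            return data.at("cache_prompt").get<bool>();
        }
        return false;
    }
    // e.g. slot.params.cache_prompt = read_cache_prompt_flag(data);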
109 changes: 59 additions & 50 deletions examples/server/server.cpp
@@ -185,7 +185,7 @@ struct llama_client_slot
llama_sampling_context *ctx_sampling = nullptr;

int32_t ga_i = 0; // group-attention state
int32_t ga_n = 1;// group-attention factor
int32_t ga_n = 1; // group-attention factor
int32_t ga_w = 512; // group-attention width

int32_t n_past_se = 0; // self-extend
@@ -219,7 +219,8 @@ struct llama_client_slot
sent_token_probs_index = 0;
infill = false;
ga_i = 0;
n_past_se = 0;

generated_token_probs.clear();

for (slot_image & img : images)
@@ -1227,7 +1228,7 @@ struct llama_server_context
std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
for (int i = 0; i < (int) append_tokens.size(); ++i)
{
llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true);
llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true);
slot.n_past += 1;
}
}
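
Throughout this hunk (and several below), the fix is to pass absolute cache positions: the shared system prompt occupies the first cells of every sequence, so a slot-local position must be offset by `system_tokens.size()` before it reaches `llama_batch_add`. A hypothetical helper makes the invariant explicit (not part of the PR; types from llama.h):

    // hypothetical helper: absolute KV position of a slot-local position n_past.
    // cells [0, system_tokens.size()) hold the system prompt in every sequence,
    // so a slot's tokens start right after them.
    static llama_pos kv_pos(const std::vector<llama_token> & system_tokens, int32_t n_past) {
        return (llama_pos) system_tokens.size() + n_past;
    }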
@@ -1295,6 +1296,8 @@ struct llama_server_context
for (llama_client_slot &slot : slots)
{
slot.cache_tokens.clear();
slot.n_past = 0;
slot.n_past_se = 0;
}
}
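
Clearing the cache now also zeroes each slot's position counters; with a stale `n_past`, the next decode could address positions that no longer line up with the (now empty) cache. Roughly, the updated function reads as follows (abridged reconstruction; assumes `kv_cache_clear` wraps `llama_kv_cache_clear`, as elsewhere in server.cpp):

    void kv_cache_clear() {
        llama_kv_cache_clear(ctx);              // drop every sequence's cells
        for (llama_client_slot & slot : slots) {
            slot.cache_tokens.clear();          // forget the cached token text
            slot.n_past    = 0;                 // added: resync the slot position
            slot.n_past_se = 0;                 // added: same for self-extend
        }
    }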

@@ -1364,26 +1367,26 @@ struct llama_server_context
kv_cache_clear();
}
return true;
} else {
task_server task;
task.type = TASK_TYPE_NEXT_RESPONSE;
task.target_id = -1;
queue_tasks.post(task);
}

task_server task;
task.type = TASK_TYPE_NEXT_RESPONSE;
task.target_id = -1;
queue_tasks.post(task);

for (llama_client_slot &slot : slots)
{
if (slot.ga_n == 1)
{
if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx)
{
// Shift context
const int n_left = slot.n_past - slot.params.n_keep - 1;
const int n_left = system_tokens.size() + slot.n_past - slot.params.n_keep - 1;
const int n_discard = n_left / 2;

LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
llama_kv_cache_seq_rm (ctx, slot.id, slot.params.n_keep + 1 , slot.params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, system_tokens.size() + slot.n_past, -n_discard);

for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
{
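
The shift keeps the first `n_keep + 1` cells (the +1 accounting for the BOS token), discards the oldest half of the remainder, and slides everything else left so decoding can continue. A self-contained toy of the same arithmetic on a plain vector (illustrative numbers only; the real code operates on the KV cache and `cache_tokens` in tandem, with `n_left` additionally offset by `system_tokens.size()`):

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<int> cache = {0, 1, 2, 3, 4, 5, 6, 7}; // a "full" context of 8 tokens
        const int n_keep    = 1;                           // pinned prefix after BOS
        const int n_left    = (int) cache.size() - n_keep - 1;
        const int n_discard = n_left / 2;                  // drop the oldest half: 3 tokens
        // mirrors llama_kv_cache_seq_rm plus the compaction loop in the hunk above
        cache.erase(cache.begin() + n_keep + 1,
                    cache.begin() + n_keep + 1 + n_discard);
        for (int t : cache) printf("%d ", t);              // prints: 0 1 5 6 7
        return 0;
    }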
@@ -1429,8 +1432,10 @@ struct llama_server_context
slot.i_batch = batch.n_tokens;

const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;

// TODO: we always have to take into account the "system_tokens"
// this is not great and needs to be improved somehow
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
slot.n_past += 1;
}

@@ -1481,8 +1486,8 @@ struct llama_server_context

prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(model));
prompt_tokens = prefix_tokens;
}
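
Read in order, the insertions above assemble the fill-in-the-middle prompt in this layout (a summary of the shown statements, using the special-token getters named in the code; nothing new):

    // [BOS] [PRE] ...prefix tokens... [SUF] ...suffix tokens... [MID]
    // BOS = llama_token_bos,    PRE = llama_token_prefix,
    // SUF = llama_token_suffix, MID = llama_token_middle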
@@ -1582,19 +1587,22 @@ struct llama_server_context
}

LOG_VERBOSE("prompt ingested", {
{"n_past", slot.n_past},
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
{"n_past", slot.n_past},
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
{"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
});

const bool has_images = process_images(slot);

// process the prefix of first image
std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;

int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
int ga_i = slot.ga_i;

int32_t ga_i = slot.ga_i;
int32_t ga_n = slot.ga_n;
int32_t ga_w = slot.ga_w;

for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
{
if (slot.ga_n != 1)
@@ -1606,7 +1614,7 @@ struct llama_server_context
}
}
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
slot_npast += 1;
slot_npast++;
}

if (has_images && !ingest_images(slot, n_batch))
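
When `slot.ga_n != 1`, `slot_npast` advances through self-extend's remapped positions rather than raw ones; the remapping itself sits in the collapsed region above and is not reproduced here. As a rough illustration of the idea only (not the elided code): grouped attention lets `ga_n` consecutive tokens share one KV position inside a window of width `ga_w`, stretching the effective context by roughly a factor of `ga_n`.

    // illustrative sketch only: map a raw position to its grouped position,
    // so n_ctx positions can address on the order of ga_n * n_ctx tokens
    static int32_t grouped_pos(int32_t pos, int32_t ga_n) {
        return pos / ga_n; // ga_n consecutive tokens collapse onto one position
    }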
@@ -1666,6 +1674,7 @@ struct llama_server_context
slot.n_past_se += n_tokens;
}
}

llama_batch batch_view =
{
n_tokens,
@@ -1782,51 +1791,51 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
if (llama_mlock_supported())
{
printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n");
printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n");
}
if (llama_mmap_supported())
{
printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
}
printf(" --numa attempt optimizations that help on some NUMA systems\n");
printf(" --numa attempt optimizations that help on some NUMA systems\n");
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
printf(" -ngl N, --n-gpu-layers N\n");
printf(" number of layers to store in VRAM\n");
printf(" number of layers to store in VRAM\n");
printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
printf(" how to split the model across multiple GPUs, one of:\n");
printf(" - none: use one GPU only\n");
printf(" - layer (default): split layers and KV across GPUs\n");
printf(" - row: split rows across GPUs\n");
printf(" how to split the model across multiple GPUs, one of:\n");
printf(" - none: use one GPU only\n");
printf(" - layer (default): split layers and KV across GPUs\n");
printf(" - row: split rows across GPUs\n");
printf(" -ts SPLIT --tensor-split SPLIT\n");
printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
printf(" or for intermediate results and KV (with split-mode = row)\n");
printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
printf(" or for intermediate results and KV (with split-mode = row)\n");
#endif
printf(" -m FNAME, --model FNAME\n");
printf(" model path (default: %s)\n", params.model.c_str());
printf(" model path (default: %s)\n", params.model.c_str());
printf(" -a ALIAS, --alias ALIAS\n");
printf(" set an alias for the model, will be added as `model` field in completion response\n");
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" -spf FNAME, --system-prompt-file FNAME\n");
printf(" Set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA.\n");
printf(" --log-disable disables logging to a file.\n");
printf(" set an alias for the model, will be added as `model` field in completion response\n");
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" -spf FNAME, --system-prompt-file FNAME\n");
printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA.\n");
printf(" --log-disable disables logging to a file.\n");
printf("\n");
printf(" --override-kv KEY=TYPE:VALUE\n");
printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf(" -gan N, --grp-attn-n N Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
printf(" -gaw N, --grp-attn-w N Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf(" -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
printf("\n");
}
