Port of self extension to server #5104

Merged · 18 commits · Jan 27, 2024
3 changes: 2 additions & 1 deletion examples/server/README.md
@@ -30,7 +30,8 @@ Command line options:
- `-cb`, `--cont-batching`: enable continuous batching (a.k.a dynamic batching) (default: disabled)
- `-spf FNAME`, `--system-prompt-file FNAME`: Set a file to load a system prompt (initial prompt of all slots); useful for chat applications. [See more](#change-system-prompt-on-runtime)
- `--mmproj MMPROJ_FILE`: Path to a multimodal projector file for LLaVA.

- `--grp-attn-n`: Set the group attention factor to extend context size through self-extend (default: 1 = disabled), used together with group attention width `--grp-attn-w`
- `--grp-attn-w`: Set the group attention width to extend context size through self-extend (default: 512), used together with group attention factor `--grp-attn-n`
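For example (illustrative values only; the model path and context size below are placeholders, not defaults introduced by this PR), a model trained on a 4096-token context could be served with a self-extended window of roughly 16k tokens via `./server -m models/7B/model.gguf -c 16384 --grp-attn-n 4 --grp-attn-w 512`.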
## Build

server is built alongside everything else from the root of the project
164 changes: 141 additions & 23 deletions examples/server/server.cpp
@@ -184,6 +184,12 @@ struct llama_client_slot
struct llama_sampling_params sparams;
llama_sampling_context *ctx_sampling = nullptr;

int32_t ga_i = 0;   // group-attention state
int32_t ga_n = 1;   // group-attention factor
int32_t ga_w = 512; // group-attention width

int32_t n_past_se = 0; // self-extend: n_past in the compressed position space

// multimodal
std::vector<slot_image> images;
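(Note: `n_past_se` tracks the slot's position in the compressed, self-extended position space, while `n_past` keeps counting the raw tokens consumed; the two only diverge when `ga_n > 1`.)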

@@ -212,7 +218,8 @@ struct llama_client_slot
sent_count = 0;
sent_token_probs_index = 0;
infill = false;

ga_i = 0;
n_past_se = 0;
generated_token_probs.clear();

for (slot_image & img : images)
@@ -399,9 +406,26 @@ struct llama_server_context

slot.id = i;
slot.n_ctx = n_ctx_slot;

LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);

const int ga_n = params.grp_attn_n;
const int ga_w = params.grp_attn_w;

if (ga_n != 1) {
GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
LOG_TEE(" -> Slot %i - self-extend: ga_n = %d, ga_w = %d\n", slot.id, ga_n, ga_w);
}

slot.ga_i = 0;
slot.ga_n = ga_n;
slot.ga_w = ga_w;

slot.reset();

slots.push_back(slot);
}
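The commented-out asserts sketch the intended sizing: self-extend stretches the usable window to roughly `n_ctx_train * ga_n`. As an illustrative (not enforced) example, `ga_n = 4` with `ga_w = 512` on a model trained with a 4096-token context targets a context of about 16384 tokens.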

@@ -1349,32 +1373,35 @@ struct llama_server_context

for (llama_client_slot &slot : slots)
{
if (slot.ga_n == 1)
{
if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
{
// Shift context
const int n_left = slot.n_past - slot.params.n_keep - 1;
const int n_discard = n_left / 2;

LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
llama_kv_cache_seq_rm (ctx, slot.id, slot.params.n_keep + 1 , slot.params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);

for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
{
slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
}

slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);

slot.n_past -= n_discard;

slot.truncated = true;

LOG_VERBOSE("context shift", {
{ "n_ctx", n_ctx },
{ "n_keep", params.n_keep },
{ "n_left", n_left },
});
}
}
}
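Note that the rolling context shift above now applies only to slots with `ga_n == 1`; self-extended slots keep their tokens and instead compress positions through the group-attention shift/division performed later in the batch loop.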

@@ -1401,7 +1428,8 @@ struct llama_server_context

slot.i_batch = batch.n_tokens;

const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);

slot.n_past += 1;
}
@@ -1499,6 +1527,8 @@ struct llama_server_context
llama_sampling_reset(slot.ctx_sampling);

slot.n_past = 0;
slot.n_past_se = 0;
slot.ga_i = 0;
slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
}
else
@@ -1512,6 +1542,25 @@ struct llama_server_context
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;

if (slot.ga_n != 1)
{
int ga_i = 0;
int32_t ga_n = slot.ga_n;
int32_t ga_w = slot.ga_w;
int32_t slot_npast = 0;
for (int k = 0; k < slot.n_past; ++k)
{
while (slot_npast >= ga_i + ga_w) {
const int bd = (ga_w/ga_n)*(ga_n - 1);
slot_npast -= bd;
ga_i += ga_w/ga_n;
}
slot_npast++;
}
slot.n_past_se = slot_npast;
slot.ga_i = ga_i;
}

LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
}
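The loop above replays the self-extend compression so that a prompt prefix reused from the KV cache resumes at the correct compressed position. A minimal standalone sketch of the same mapping (a hypothetical helper for illustration, not part of this PR):

```cpp
#include <cstdint>

// Map an absolute token count to the compressed position used by self-extend.
// Mirrors the loop above: whenever the position crosses the current window
// (ga_i + ga_w), that window is folded down by a factor of ga_n.
static int32_t self_extend_position(int32_t n_tokens, int32_t ga_n, int32_t ga_w) {
    int32_t pos  = 0;
    int32_t ga_i = 0;
    for (int32_t k = 0; k < n_tokens; ++k) {
        while (pos >= ga_i + ga_w) {
            pos  -= (ga_w / ga_n) * (ga_n - 1);
            ga_i +=  ga_w / ga_n;
        }
        pos++;
    }
    return pos; // grows roughly as n_tokens / ga_n once past the first window
}
```

For instance, with `ga_n = 4` and `ga_w = 512`, 2048 cached tokens resume at compressed position 896 rather than 2048.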

@@ -1526,6 +1575,10 @@ struct llama_server_context
// we have to evaluate at least 1 token to generate logits.
LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
slot.n_past--;
if (slot.ga_i > 0)
{
slot.n_past_se--;
}
}

LOG_VERBOSE("prompt ingested", {
@@ -1538,9 +1591,22 @@ struct llama_server_context

// process the prefix of first image
std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
int ga_i = slot.ga_i;
int32_t ga_n = slot.ga_n;
int32_t ga_w = slot.ga_w;
for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
{
if (slot.ga_n != 1)
{
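// fold the position into the current group-attention window so it matches the self-extended KV-cache layout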
while (slot_npast >= ga_i + ga_w) {
const int bd = (ga_w/ga_n)*(ga_n - 1);
slot_npast -= bd;
ga_i += ga_w/ga_n;
}
}
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
slot_npast += 1;
}

if (has_images && !ingest_images(slot, n_batch))
@@ -1570,6 +1636,36 @@ struct llama_server_context
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
{
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));

for (auto & slot : slots)
{
if (slot.ga_n != 1)
{
// context extension via Self-Extend
while (slot.n_past_se >= slot.ga_i + slot.ga_w)
{
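// ib: number of window folds performed so far (ga_i advances by ga_w/ga_n per fold)
// bd: positions saved by folding one window, i.e. ga_w - ga_w/ga_n
// dd: shift that re-attaches the tail of the cache right after the divided window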
const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;

LOG_TEE("\n");
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);

llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div  (ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);

slot.n_past_se -= bd;

slot.ga_i += slot.ga_w / slot.ga_n;

LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
}
slot.n_past_se += n_tokens;
}
}
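As a rough worked trace of the bookkeeping above (the values `ga_n = 4`, `ga_w = 512` are purely illustrative, not taken from this PR), the first fold triggers once `n_past_se` reaches 512:

```cpp
// Standalone sketch (not server code): the first self-extend fold for
// ga_n = 4, ga_w = 512, starting from ga_i = 0, n_past_se = 512.
#include <cstdio>

int main() {
    int ga_n = 4, ga_w = 512, ga_i = 0, n_past_se = 512;

    const int ib = (ga_n * ga_i) / ga_w;           // 0: no folds done yet
    const int bd = (ga_w / ga_n) * (ga_n - 1);     // 384: positions saved per fold
    const int dd = (ga_w / ga_n) - ib * bd - ga_w; // -384: shift applied to the tail

    // The KV range [0, 512) is divided by 4, so it now occupies positions [0, 128).
    n_past_se -= bd;          // 512 -> 128
    ga_i      += ga_w / ga_n; // 0 -> 128; the next fold triggers at n_past_se >= 640

    printf("ib=%d bd=%d dd=%d n_past_se=%d ga_i=%d\n", ib, bd, dd, n_past_se, ga_i);
    return 0;
}
```

Each fold frees `bd` positions, so RoPE positions advance roughly `ga_n` times slower than the number of tokens actually ingested.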
llama_batch batch_view =
{
n_tokens,
Expand All @@ -1583,6 +1679,7 @@ struct llama_server_context
};

const int ret = llama_decode(ctx, batch_view);

if (ret != 0)
{
if (n_batch == 1 || ret < 0)
@@ -1728,6 +1825,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" --override-kv KEY=TYPE:VALUE\n");
printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf(" -gan N, --grp-attn-n N Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
printf(" -gaw N, --grp-attn-w N Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
printf("\n");
}

@@ -1913,6 +2012,25 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.n_threads = std::stoi(argv[i]);
}
else if (arg == "--grp-attn-n" || arg == "-gan")
{
if (++i >= argc)
{
invalid_param = true;
break;
}

params.grp_attn_n = std::stoi(argv[i]);
}
else if (arg == "--grp-attn-w" || arg == "-gaw")
{
if (++i >= argc)
{
invalid_param = true;
break;
}

params.grp_attn_w = std::stoi(argv[i]);
}
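These options only populate `params.grp_attn_n` and `params.grp_attn_w`; the consistency checks (`ga_n > 0`, `ga_w` a multiple of `ga_n`) run later, when each slot is initialized.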
else if (arg == "--threads-batch" || arg == "-tb")
{
if (++i >= argc)