Skip to content

Commit

Permalink
server : allow using LoRA adapters per-request (#10994)
Browse files Browse the repository at this point in the history
* slot.can_batch_with

* lora per request

* test: force disable cache prompt

* move can_batch_with check

* fix condition

* add slow test with llama 8b

* update docs

* move lora change task to queue

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* lora_base

* remove redundant check

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
  • Loading branch information
ngxson and ggerganov authored Jan 2, 2025
1 parent a45433b commit 0da5d86
Show file tree
Hide file tree
Showing 8 changed files with 235 additions and 59 deletions.
6 changes: 6 additions & 0 deletions examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,8 @@ These words will not be included in the completion, so make sure to add them to

`response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error. Note that fields with a slash will be unnested; for example, `generation_settings/n_predict` will move the field `n_predict` from the `generation_settings` object to the root of the response and give it a new name.

`lora`: A list of LoRA adapters to be applied to this specific request. Each object in the list must contain `id` and `scale` fields. For example: `[{"id": 0, "scale": 0.5}, {"id": 1, "scale": 1.1}]`. If a LoRA adapter is not specified in the list, its scale will default to `0.0`. Please note that requests with different LoRA configurations will not be batched together, which may result in performance degradation.

**Response format**

- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
Expand Down Expand Up @@ -945,6 +947,8 @@ This endpoint returns the loaded LoRA adapters. You can add adapters using `--lo

By default, all adapters will be loaded with scale set to 1. To initialize all adapters scale to 0, add `--lora-init-without-apply`

Please note that this value will be overwritten by the `lora` field for each request.

If an adapter is disabled, the scale will be set to 0.

**Response format**
Expand All @@ -966,6 +970,8 @@ If an adapter is disabled, the scale will be set to 0.

### POST `/lora-adapters`: Set list of LoRA adapters

This sets the global scale for LoRA adapters. Please note that this value will be overwritten by the `lora` field for each request.

To disable an adapter, either remove it from the list below, or set scale to 0.

**Request format**
Expand Down
116 changes: 76 additions & 40 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ struct slot_params {
int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

std::vector<common_lora_adapter_container> lora;

std::vector<std::string> antiprompt;
std::vector<std::string> response_fields;
bool timings_per_token = false;
Expand All @@ -120,6 +122,11 @@ struct slot_params {
samplers.emplace_back(common_sampler_type_to_str(sampler));
}

json lora = json::array();
for (size_t i = 0; i < this->lora.size(); ++i) {
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
}

return json {
{"n_predict", n_predict}, // Server configured n_predict
{"seed", sampling.seed},
Expand Down Expand Up @@ -160,6 +167,7 @@ struct slot_params {
{"speculative.p_min", speculative.p_min},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"lora", lora},
};
}
};
Expand Down Expand Up @@ -189,12 +197,16 @@ struct server_task {
// used by SERVER_TASK_TYPE_METRICS
bool metrics_reset_bucket = false;

// used by SERVER_TASK_TYPE_SET_LORA
std::vector<common_lora_adapter_container> set_lora;

server_task(server_task_type type) : type(type) {}

static slot_params params_from_json_cmpl(
const llama_model * model,
const llama_context * ctx,
const common_params & params_base,
const std::vector<common_lora_adapter_container> & lora_base,
const json & data) {
slot_params params;

Expand Down Expand Up @@ -251,6 +263,16 @@ struct server_task {
params.speculative.n_min = std::max(params.speculative.n_min, 2);
params.speculative.n_max = std::max(params.speculative.n_max, 0);

if (data.contains("lora")) {
if (data.at("lora").is_array()) {
params.lora = parse_lora_request(lora_base, data.at("lora"));
} else {
throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
}
} else {
params.lora = lora_base;
}

// TODO: add more sanity checks for the input parameters

if (params.sampling.penalty_last_n < -1) {
Expand Down Expand Up @@ -1110,6 +1132,8 @@ struct server_slot {

common_speculative * spec = nullptr;

std::vector<common_lora_adapter_container> lora;

// the index relative to completion multi-task request
size_t index = 0;

Expand Down Expand Up @@ -1191,6 +1215,11 @@ struct server_slot {
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
}

bool can_batch_with(server_slot & other_slot) {
return is_non_causal() == other_slot.is_non_causal()
&& are_lora_equal(lora, other_slot.lora);
}

bool has_budget(const common_params & global_params) {
if (params.n_predict == -1 && global_params.n_predict == -1) {
return true; // limitless
Expand Down Expand Up @@ -1600,7 +1629,7 @@ struct server_context {

llama_model * model = nullptr;
llama_context * ctx = nullptr;
std::vector<common_lora_adapter_container> loras;
std::vector<common_lora_adapter_container> lora;

llama_model * model_dft = nullptr;
llama_context_params cparams_dft;
Expand Down Expand Up @@ -1667,7 +1696,7 @@ struct server_context {

model = llama_init.model;
ctx = llama_init.context;
loras = llama_init.lora_adapters;
lora = llama_init.lora_adapters;

if (model == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
Expand Down Expand Up @@ -1866,6 +1895,12 @@ struct server_context {
slot.params = std::move(task.params);
slot.prompt_tokens = std::move(task.prompt_tokens);

if (!are_lora_equal(task.params.lora, slot.lora)) {
// if lora is changed, we cannot reuse cached tokens
slot.cache_tokens.clear();
slot.lora = std::move(task.params.lora);
}

SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());

if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
Expand Down Expand Up @@ -2557,7 +2592,7 @@ struct server_context {
} break;
case SERVER_TASK_TYPE_SET_LORA:
{
common_lora_adapters_apply(ctx, loras);
lora = std::move(task.set_lora);
auto res = std::make_unique<server_task_result_apply_lora>();
res->id = task.id;
queue_results.send(std::move(res));
Expand Down Expand Up @@ -2634,12 +2669,22 @@ struct server_context {
// start populating the batch for this iteration
common_batch_clear(batch);

// track if given slot can be batched with slots already in the batch
server_slot * slot_batched = nullptr;

// frist, add sampled tokens from any ongoing sequences
for (auto & slot : slots) {
if (slot.state != SLOT_STATE_GENERATING) {
continue;
}

// check if we can batch this slot with the previous one
if (!slot_batched) {
slot_batched = &slot;
} else if (!slot_batched->can_batch_with(slot)) {
continue;
}

slot.i_batch = batch.n_tokens;

common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
Expand All @@ -2658,15 +2703,18 @@ struct server_context {
int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);

// track if this is an embedding or non-embedding batch
// if we've added sampled tokens above, we are in non-embedding mode
// -1: none, 0: non-embedding, 1: embedding
// TODO: make enum
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;

// next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
// check if we can batch this slot with the previous one
if (slot.is_processing()) {
if (!slot_batched) {
slot_batched = &slot;
} else if (!slot_batched->can_batch_with(slot)) {
continue;
}
}

// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
auto & prompt_tokens = slot.prompt_tokens;
Expand Down Expand Up @@ -2827,14 +2875,6 @@ struct server_context {
}
}

// check that we are in the right batch_type, if not defer the slot
int slot_type = slot.is_non_causal();
if (batch_type == -1) {
batch_type = slot_type;
} else if (batch_type != slot_type) {
continue;
}

// keep only the common part
if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
// could not partially delete (likely using a non-Transformer model)
Expand Down Expand Up @@ -2902,8 +2942,12 @@ struct server_context {

SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);

// make sure we're in the right embedding mode
llama_set_embeddings(ctx, batch_type == 1);
if (slot_batched) {
// make sure we're in the right embedding mode
llama_set_embeddings(ctx, slot_batched->is_non_causal());
// apply lora, only need to do it once per batch
common_lora_adapters_apply(ctx, slot_batched->lora);
}

// process the created batch of tokens
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
Expand Down Expand Up @@ -3623,7 +3667,12 @@ int main(int argc, char ** argv) {
task.index = i;

task.prompt_tokens = std::move(tokenized_prompts[i]);
task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
task.params = server_task::params_from_json_cmpl(
ctx_server.model,
ctx_server.ctx,
ctx_server.params_base,
ctx_server.lora,
data);
task.id_selected_slot = json_value(data, "id_slot", -1);

// OAI-compat
Expand Down Expand Up @@ -4049,8 +4098,8 @@ int main(int argc, char ** argv) {

const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
json result = json::array();
for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
auto & lora = ctx_server.loras[i];
for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
auto & lora = ctx_server.lora[i];
result.push_back({
{"id", i},
{"path", lora.path},
Expand All @@ -4062,27 +4111,14 @@ int main(int argc, char ** argv) {
};

const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
const std::vector<json> body = json::parse(req.body);
int max_idx = ctx_server.loras.size();

// clear existing value
for (auto & lora : ctx_server.loras) {
lora.scale = 0.0f;
}

// set value
for (auto entry : body) {
int id = entry.at("id");
float scale = entry.at("scale");
if (0 <= id && id < max_idx) {
ctx_server.loras[id].scale = scale;
} else {
throw std::runtime_error("invalid adapter id");
}
const json body = json::parse(req.body);
if (!body.is_array()) {
res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
return;
}

server_task task(SERVER_TASK_TYPE_SET_LORA);
task.id = ctx_server.queue_tasks.get_new_id();
task.set_lora = parse_lora_request(ctx_server.lora, body);
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);

Expand Down
6 changes: 6 additions & 0 deletions examples/server/tests/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ To run with stdout/stderr display in real time (verbose output, but useful for d
DEBUG=1 ./tests.sh -s -v -x
```

To run single test unit:

```shell
./tests.sh unit/test_{name of test case here}.py -v -x
```
Hint: You can compile and run test in single command, useful for local developement:
```shell
Expand Down
1 change: 1 addition & 0 deletions examples/server/tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ numpy~=1.26.4
openai~=1.55.3
prometheus-client~=0.20.0
requests~=2.32.3
wget~=3.2
Loading

0 comments on commit 0da5d86

Please sign in to comment.