Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3000,6 +3000,8 @@ server_context_meta server_context::get_meta() const {
/* fim_rep_token */ llama_vocab_fim_rep(impl->vocab),
/* fim_sep_token */ llama_vocab_fim_sep(impl->vocab),

/* logit_bias_eog */ impl->params_base.sampling.logit_bias_eog,

/* model_vocab_type */ llama_vocab_type(impl->vocab),
/* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab),
/* model_n_ctx_train */ llama_model_n_ctx_train(impl->model),
Expand Down Expand Up @@ -3084,6 +3086,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
ctx_server.vocab,
params,
meta->slot_n_ctx,
meta->logit_bias_eog,
data);
task.id_slot = json_value(data, "id_slot", -1);

Expand Down
3 changes: 3 additions & 0 deletions tools/server/server-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ struct server_context_meta {
llama_token fim_rep_token;
llama_token fim_sep_token;

// sampling
std::vector<llama_logit_bias> logit_bias_eog;

// model meta
enum llama_vocab_type model_vocab_type;
int32_t model_vocab_n_tokens;
Expand Down
3 changes: 2 additions & 1 deletion tools/server/server-task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ task_params server_task::params_from_json_cmpl(
const llama_vocab * vocab,
const common_params & params_base,
const int n_ctx_slot,
const std::vector<llama_logit_bias> & logit_bias_eog,
const json & data) {
task_params params;

Expand Down Expand Up @@ -562,7 +563,7 @@ task_params server_task::params_from_json_cmpl(
if (params.sampling.ignore_eos) {
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
logit_bias_eog.begin(), logit_bias_eog.end());
}
}

Expand Down
1 change: 1 addition & 0 deletions tools/server/server-task.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ struct server_task {
const llama_vocab * vocab,
const common_params & params_base,
const int n_ctx_slot,
const std::vector<llama_logit_bias> & logit_bias_eog,
const json & data);

// utility function
Expand Down
43 changes: 43 additions & 0 deletions tools/server/tests/unit/test_ignore_eos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
from utils import *

server = ServerPreset.tinyllama2()


@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()


def test_ignore_eos_populates_logit_bias():
"""ignore_eos=true must add EOG logit biases to generation_settings."""
global server
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 8,
"prompt": "Once upon a time",
"ignore_eos": True,
"temperature": 0.0,
})
assert res.status_code == 200
# EOG token biases must be present with -inf bias
logit_bias = res.body["generation_settings"]["logit_bias"]
assert len(logit_bias) > 0
for entry in logit_bias:
assert entry["bias"] is None # null in JSON represents -inf


def test_ignore_eos_false_no_logit_bias():
"""ignore_eos=false (default) must NOT add EOG logit biases."""
global server
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 8,
"prompt": "Once upon a time",
"ignore_eos": False,
"temperature": 0.0,
})
assert res.status_code == 200
logit_bias = res.body["generation_settings"]["logit_bias"]
assert len(logit_bias) == 0
Loading