Skip to content

Commit

Permalink
changed token functions to use new model variants
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcusDunn committed Oct 23, 2023
1 parent 2df3801 commit fc5bb85
Show file tree
Hide file tree
Showing 10 changed files with 33 additions and 33 deletions.
6 changes: 3 additions & 3 deletions common/train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,8 @@ int64_t get_example_targets_batch(
int64_t used_samples = 0;

ggml_set_f32(target_probs, 0.0f);
llama_token bos = llama_token_bos(lctx);
llama_token eos = llama_token_eos(lctx);
llama_token bos = llama_token_bos(llama_get_model(lctx));
llama_token eos = llama_token_eos(llama_get_model(lctx));
// printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
for (int k=0; k<n_batch; ++k) {
// printf("%s: batch %d\n", __func__, k);
Expand Down Expand Up @@ -924,7 +924,7 @@ size_t tokenize_file(
for (llama_token token=0; token < n_vocab; ++token) {
max_token_text_size = std::max(
max_token_text_size,
strlen(llama_token_get_text(lctx, token)));
strlen(llama_token_get_text(llama_get_model(lctx), token)));
}

// upper bound of context byte length.
Expand Down
2 changes: 1 addition & 1 deletion examples/batched/batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
//const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);

// is it an end of stream? -> mark the stream as finished
if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
i_batch[i] = -1;
LOG_TEE("\n");
if (n_parallel > 1) {
Expand Down
2 changes: 1 addition & 1 deletion examples/beam-search/beam-search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ struct beam_search_callback_data {
// In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same.
// For example, eob can be flagged due to maximum token length, stop words, etc.
static bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, size_t n_tokens) {
return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
return n_tokens && tokens[n_tokens-1] == llama_token_eos(llama_get_model(callback_data.ctx));
}

// Function matching type llama_beam_search_callback_fn_t.
Expand Down
30 changes: 15 additions & 15 deletions examples/infill/infill.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,22 +246,22 @@ int main(int argc, char ** argv) {
if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
inp_sfx.erase(inp_sfx.begin());
}
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
if (add_bos) {
inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx));
inp_pfx.insert(inp_pfx.begin(), llama_token_bos(model));
}
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
embd_inp = inp_pfx;
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
embd_inp.push_back(llama_token_middle(ctx));
embd_inp.push_back(llama_token_middle(model));

LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());

// Should not run without any tokens
if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(ctx));
embd_inp.push_back(llama_token_bos(model));
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
}

Expand Down Expand Up @@ -577,10 +577,10 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed) {

// deal with eot token in infill mode
if ((llama_sampling_last(ctx_sampling) == llama_token_eot(ctx) || is_interacting) && params.interactive){
if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
if(is_interacting && !params.interactive_first) {
// print an eot token
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
}
fflush(stdout);
printf("\n");
Expand Down Expand Up @@ -627,14 +627,14 @@ int main(int argc, char ** argv) {
if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
inp_sfx.erase(inp_sfx.begin());
}
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
if (add_bos) {
inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx));
inp_pfx.insert(inp_pfx.begin(), llama_token_bos(model));
}
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
embd_inp = inp_pfx;
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
embd_inp.push_back(llama_token_middle(ctx));
embd_inp.push_back(llama_token_middle(model));
embd.clear();
embd_guidance.clear();
n_remain = params.n_predict;
Expand All @@ -644,7 +644,7 @@ int main(int argc, char ** argv) {
is_interacting = false;
}
// deal with end of text token in interactive mode
else if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
else if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) {
LOG("found EOS token\n");

if (params.interactive) {
Expand All @@ -661,7 +661,7 @@ int main(int argc, char ** argv) {

if (params.input_prefix_bos) {
LOG("adding input prefix BOS token\n");
embd_inp.push_back(llama_token_bos(ctx));
embd_inp.push_back(llama_token_bos(model));
}

std::string buffer;
Expand Down Expand Up @@ -724,7 +724,7 @@ int main(int argc, char ** argv) {
}

// end of text token
if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !params.interactive) {
if (!embd.empty() && embd.back() == llama_token_eos(model) && !params.interactive) {
break;
}

Expand All @@ -736,7 +736,7 @@ int main(int argc, char ** argv) {
}
}
if (!params.interactive && n_remain <= 0) {
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
fflush(stdout);
}

Expand Down
4 changes: 2 additions & 2 deletions examples/llama-bench/llama-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ struct sql_printer : public printer {
};

static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
std::vector<llama_token> tokens(n_batch, llama_token_bos(ctx));
std::vector<llama_token> tokens(n_batch, llama_token_bos(llama_get_model(ctx)));
int n_processed = 0;

llama_set_n_threads(ctx, n_threads, n_threads);
Expand All @@ -946,7 +946,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
}

static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
llama_token token = llama_token_bos(ctx);
llama_token token = llama_token_bos(llama_get_model(ctx));

llama_set_n_threads(ctx, n_threads, n_threads);

Expand Down
2 changes: 1 addition & 1 deletion examples/llava/llava-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
inline const char * sample(struct llama_context * ctx_llama, gpt_params & params, int * n_past) {
int id = sample_id(ctx_llama, params);
static std::string ret;
if (id == llama_token_eos(ctx_llama)) {
if (id == llama_token_eos(llama_get_model(ctx_llama))) {
ret = "</s>";
} else {
ret = llama_token_to_piece(ctx_llama, id);
Expand Down
2 changes: 1 addition & 1 deletion examples/parallel/parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ int main(int argc, char ** argv) {
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());

if (client.n_decoded > 2 &&
(id == llama_token_eos(ctx) ||
(id == llama_token_eos(model) ||
(params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) ||
client.response.find("User:") != std::string::npos ||
client.response.find('\n') != std::string::npos)) {
Expand Down
14 changes: 7 additions & 7 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ struct llama_server_context

if (json_value(data, "ignore_eos", false))
{
slot->sparams.logit_bias[llama_token_eos(ctx)] = -INFINITY;
slot->sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
}

const auto &logit_bias = data.find("logit_bias");
Expand Down Expand Up @@ -1056,7 +1056,7 @@ struct llama_server_context
slot.has_next_token = false;
}

if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(ctx))
if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model))
{
slot.stopped_eos = true;
slot.has_next_token = false;
Expand Down Expand Up @@ -1130,7 +1130,7 @@ struct llama_server_context

json get_formated_generation(llama_client_slot &slot)
{
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(ctx));
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() &&
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
return json {
Expand Down Expand Up @@ -1555,11 +1555,11 @@ struct llama_server_context
suffix_tokens.erase(suffix_tokens.begin());
}

prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(ctx));
prefix_tokens.push_back(llama_token_middle(model));
prompt_tokens = prefix_tokens;
}
else
Expand Down
2 changes: 1 addition & 1 deletion examples/simple/simple.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ int main(int argc, char ** argv) {
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);

// is it an end of stream?
if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
LOG_TEE("\n");

break;
Expand Down
2 changes: 1 addition & 1 deletion examples/speculative/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ int main(int argc, char ** argv) {
printf("%s", token_str.c_str());
fflush(stdout);

if (id == llama_token_eos(ctx_tgt)) {
if (id == llama_token_eos(model_tgt)) {
has_eos = true;
}

Expand Down

0 comments on commit fc5bb85

Please sign in to comment.