Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : fix embeddings #5796

Merged
merged 9 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1299,7 +1299,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.seed = params.seed;
cparams.logits_all = params.logits_all;
cparams.embedding = params.embedding;
cparams.embeddings = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type;
cparams.rope_freq_base = params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale;
Expand Down
13 changes: 9 additions & 4 deletions examples/embedding/embedding.cpp
cebtenzzre marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Collaborator

@cebtenzzre cebtenzzre Mar 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the llama_kv_cache_clear still do anything useful?

edit: I remembered that this example is used for models with causal attention as well. I won't need the equivalent if I'm just working with embedding models.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, llama_kv_cache_clear is not necessary for embedding models, but makes sense for causal attention models

Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ static std::vector<std::string> split_lines(const std::string & s) {

static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
for (size_t i = 0; i < tokens.size(); i++) {
llama_batch_add(batch, tokens[i], i, { seq_id }, false);
llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
}
}

Expand All @@ -45,9 +45,13 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}

// normalize on copy
for (int k = 0; k < n_seq; k++) {
float * emb = llama_get_embeddings_ith(ctx, k);
float * out = output + k * n_embd;
for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
continue;
}

float * emb = llama_get_embeddings_ith(ctx, i);
float * out = output + batch.seq_id[i][0] * n_embd;
normalize(emb, out, n_embd);
}
}
Expand Down Expand Up @@ -145,6 +149,7 @@ int main(int argc, char ** argv) {
for (int k = 0; k < n_prompts; k++) {
// clamp to n_batch tokens
auto & inp = inputs[k];

const uint64_t n_toks = inp.size();

// encode if at capacity
Expand Down
34 changes: 34 additions & 0 deletions examples/server-embd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import asyncio
import requests
import numpy as np

n = 8

result = []

async def requests_post_async(*args, **kwargs):
return await asyncio.to_thread(requests.post, *args, **kwargs)

async def main():
model_url = "http://127.0.0.1:6900"
responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
url= f"{model_url}/embedding",
json= {"content": str(i)*32}
) for i in range(n)])

for response in responses:
embedding = response.json()["embedding"]
print(embedding[-8:])
result.append(embedding)

asyncio.run(main())

# compute cosine similarity

for i in range(n-1):
for j in range(i+1, n):
embedding1 = np.array(result[i])
embedding2 = np.array(result[j])
similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
print(f"Similarity between {i} and {j}: {similarity:.2f}")

29 changes: 18 additions & 11 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1210,7 +1210,7 @@ struct llama_server_context
queue_results.send(res);
}

void send_embedding(server_slot &slot)
void send_embedding(server_slot & slot, const llama_batch & batch)
{
task_result res;
res.id = slot.task_id;
Expand All @@ -1219,6 +1219,7 @@ struct llama_server_context
res.stop = true;

const int n_embd = llama_n_embd(model);

if (!params.embedding)
{
LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}});
Expand All @@ -1229,12 +1230,19 @@ struct llama_server_context
}
else
{
const float *data = llama_get_embeddings(ctx);
std::vector<float> embedding(data, data + n_embd);
res.result_json = json
{
{"embedding", embedding},
};
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
continue;
}

const float * data = llama_get_embeddings_ith(ctx, i);
std::vector<float> embedding(data, data + n_embd);

res.result_json = json
{
{"embedding", embedding },
};
}
}
queue_results.send(res);
}
Expand Down Expand Up @@ -1845,7 +1853,7 @@ struct llama_server_context
ga_i += ga_w/ga_n;
}
}
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
slot_npast++;
}

Expand Down Expand Up @@ -1881,7 +1889,7 @@ struct llama_server_context

for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
{
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

for (auto & slot : slots)
{
Expand Down Expand Up @@ -1954,7 +1962,7 @@ struct llama_server_context
// prompt evaluated for embedding
if (slot.embedding)
{
send_embedding(slot);
send_embedding(slot, batch_view);
slot.release();
slot.i_batch = -1;
continue;
Expand Down Expand Up @@ -2330,7 +2338,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
break;
}
params.n_batch = std::stoi(argv[i]);
params.n_batch = std::min(512, params.n_batch);
}
else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers")
{
Expand Down
Loading
Loading