llama : distinguish token vs sequence embeddings
ggml-ci
ggerganov committed Mar 4, 2024
1 parent e66da35 commit 79e4eed
Showing 5 changed files with 128 additions and 56 deletions.
15 changes: 12 additions & 3 deletions examples/embedding/embedding.cpp
@@ -23,7 +23,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
     }
 }
 
-static void normalize(float * vec, float * out, int n) {
+static void normalize(const float * vec, float * out, int n) {
     float norm = 0;
     for (int i = 0; i < n; i++) {
         norm += vec[i] * vec[i];
@@ -50,9 +50,18 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
             continue;
         }
 
-        float * emb = llama_get_embeddings_ith(ctx, i);
+        // try to get sequence embeddings - supported only when pooling_type is not NONE
+        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+        if (embd == NULL) {
+            embd = llama_get_embeddings_ith(ctx, i);
+            if (embd == NULL) {
+                fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
+                continue;
+            }
+        }
 
         float * out = output + batch.seq_id[i][0] * n_embd;
-        normalize(emb, out, n_embd);
+        normalize(embd, out, n_embd);
     }
 }
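With this change, batch_decode prefers the pooled per-sequence embedding and falls back to per-token embeddings only when pooling is disabled; either way, one normalized row per sequence is written to output. A minimal sketch of what a caller can do with two such rows (both assumed L2-normalized by normalize() above, so cosine similarity reduces to a dot product; cosine_sim is an illustrative helper, not part of this commit):

    static float cosine_sim(const float * a, const float * b, int n_embd) {
        // for unit-length vectors the dot product already equals the
        // cosine of the angle between them
        float dot = 0.0f;
        for (int i = 0; i < n_embd; i++) {
            dot += a[i] * b[i];
        }
        return dot;
    }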

2 changes: 1 addition & 1 deletion examples/server-embd.py
@@ -13,7 +13,7 @@ async def main():
     model_url = "http://127.0.0.1:6900"
     responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
         url= f"{model_url}/embedding",
-        json= {"content": str(0)*32}
+        json= {"content": str(i)*1024}
     ) for i in range(n)])
 
     for response in responses:
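Each of the n concurrent requests now sends a distinct, longer payload (str(i)*1024 instead of the same str(0)*32 for every request), which exercises larger batched inputs and presumably makes it detectable when a response gets attributed to the wrong concurrent request.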
16 changes: 13 additions & 3 deletions examples/server/server.cpp
@@ -1235,12 +1235,22 @@ struct llama_server_context
                     continue;
                 }
 
-                const float * data = llama_get_embeddings_ith(ctx, i);
-                std::vector<float> embedding(data, data + n_embd);
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                    if (embd == NULL) {
+                        LOG_ERROR("failed to get embeddings for token", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}});
+                        res.result_json = json
+                        {
+                            {"embedding", std::vector<float>(n_embd, 0.0f)},
+                        };
+                        continue;
+                    }
+                }
 
                 res.result_json = json
                 {
-                    {"embedding", embedding },
+                    {"embedding", std::vector<float>(embd, embd + n_embd)},
                 };
             }
         }
141 changes: 94 additions & 47 deletions llama.cpp
@@ -1983,7 +1983,12 @@ struct llama_context {
     bool logits_all = false;
 
     // embeddings output (2-dimensional array: [n_tokens][n_embd])
-    std::vector<float> embeddings;
+    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
+    std::vector<float> embd;
+
+    // sequence embeddings output (map of [n_embd] vectors)
+    // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
+    std::map<llama_seq_id, std::vector<float>> embd_seq;
 
     // memory buffers used to evaluate the model
     std::vector<uint8_t> buf_compute_meta;
@@ -6243,12 +6248,23 @@ struct llm_build_context {
         cur = inpL;
 
         // pooling layer
-        if (pooling_type == LLAMA_POOLING_TYPE_MEAN) {
-            cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
-        } else if (pooling_type == LLAMA_POOLING_TYPE_CLS) {
-            cur = ggml_get_rows(ctx0, cur, inp_cls);
-        } else {
-            GGML_ASSERT(pooling_type == LLAMA_POOLING_TYPE_NONE && "Invalid pooling type");
+        switch (pooling_type) {
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    // nop
+                } break;
+            case LLAMA_POOLING_TYPE_MEAN:
+                {
+                    cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
+                } break;
+            case LLAMA_POOLING_TYPE_CLS:
+                {
+                    cur = ggml_get_rows(ctx0, cur, inp_cls);
+                } break;
+            case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                {
+                    GGML_ASSERT(false && "Invalid pooling type");
+                } break;
         }
         cb(cur, "result_embd", -1);
 
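The MEAN branch realizes pooling as a single matrix multiplication: cur holds the token embeddings and inp_mean holds averaging weights prepared elsewhere in the graph. A plain-C sketch of the computation this matmul performs (the 1/len weighting and the flat row-major layout are assumptions based on the surrounding code, not shown in this diff):

    // pooled[s][d] = sum over t of inp_mean[s][t] * hidden[t][d], where
    // inp_mean[s][t] is 1/len(s) if token t belongs to sequence s and 0
    // otherwise, so each output row is the mean of one sequence's tokens
    for (int s = 0; s < n_seqs; s++) {
        for (int d = 0; d < n_embd; d++) {
            float acc = 0.0f;
            for (int t = 0; t < n_tokens; t++) {
                acc += inp_mean[s*n_tokens + t] * hidden[t*n_embd + d];
            }
            pooled[s*n_embd + d] = acc;
        }
    }

The CLS branch instead uses ggml_get_rows to pick out one row per sequence, indexed by inp_cls (typically each sequence's first token).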
@@ -8259,17 +8275,23 @@ static int llama_decode_internal(
     struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
     struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
 
-    if (strcmp(res->name, "result_output") == 0) {
-        // the embeddings could be the second to last tensor, or the third to last tensor
-        if (strcmp(embd->name, "result_norm") != 0) {
-            embd = gf->nodes[gf->n_nodes - 3];
-            GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
-        }
-    } else if (strcmp(res->name, "result_embd") == 0) {
-        embd = res;
-        res = nullptr;
-    } else {
-        GGML_ASSERT(false);
-    }
+    if (!hparams.causal_attn) {
+        res = nullptr; // do not extract logits for embedding models such as BERT
+
+        // token or sequence embeddings
+        embd = gf->nodes[gf->n_nodes - 1];
+
+        GGML_ASSERT(strcmp(embd->name, "result_embd") == 0);
+    } else {
+        if (strcmp(res->name, "result_output") == 0) {
+            // the token embeddings could be the second to last tensor, or the third to last tensor
+            if (strcmp(embd->name, "result_norm") != 0) {
+                embd = gf->nodes[gf->n_nodes - 3];
+                GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
+            }
+        } else {
+            GGML_ASSERT(false && "missing result_output tensor");
+        }
+    }
 
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
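The selection now keys off hparams.causal_attn: non-causal (encoder-style) models such as BERT end their graph with result_embd and produce no logits, while causal models end with result_output preceded by result_norm. A hypothetical helper expressing the same idea as a lookup by tensor name rather than by fixed offsets from the graph tail (not part of this commit; gf->nodes and tensor names as used above):

    static struct ggml_tensor * find_result(struct ggml_cgraph * gf, const char * name) {
        // scan backwards, since the result tensors sit at the end of the graph
        for (int i = gf->n_nodes - 1; i >= 0; i--) {
            if (strcmp(gf->nodes[i]->name, name) == 0) {
                return gf->nodes[i];
            }
        }
        return NULL;
    }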
@@ -8368,30 +8390,46 @@ static int llama_decode_internal(
 
     // extract embeddings
     if (cparams.embeddings && embd) {
-        auto & embeddings_out = lctx.embeddings;
-
         ggml_backend_t backend_embd = ggml_backend_sched_get_node_backend(lctx.sched, embd);
         GGML_ASSERT(backend_embd != nullptr);
 
-        if (batch.logits) {
-            embeddings_out.resize(n_embd * n_tokens);
-            for (uint32_t i = 0; i < n_tokens; i++) {
-                if (batch.logits[i] == 0) {
-                    continue;
-                }
-                switch (hparams.pooling_type) {
-                    case LLAMA_POOLING_TYPE_CLS:
-                        ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float));
-                        break;
-                    case LLAMA_POOLING_TYPE_MEAN:
-                    case LLAMA_POOLING_TYPE_NONE:
-                        ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
-                        break;
-                    default:
-                        GGML_ASSERT(false && "unknown pooling type");
-                        break;
-                }
-            }
-        }
+        switch (cparams.pooling_type) {
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    // extract token embeddings
+                    auto & embd_out = lctx.embd;
+
+                    if (batch.logits) {
+                        embd_out.resize(n_embd * n_tokens);
+                        for (uint32_t i = 0; i < n_tokens; i++) {
+                            if (batch.logits[i] == 0) {
+                                continue;
+                            }
+
+                            ggml_backend_tensor_get_async(backend_embd, embd, embd_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
+                        }
+                    }
+                } break;
+            case LLAMA_POOLING_TYPE_CLS:
+            case LLAMA_POOLING_TYPE_MEAN:
+                {
+                    // extract sequence embeddings
+                    auto & embd_seq_out = lctx.embd_seq;
+                    embd_seq_out.clear();
+
+                    for (uint32_t i = 0; i < n_tokens; i++) {
+                        const llama_seq_id seq_id = batch.seq_id[i][0];
+                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                            continue;
+                        }
+                        embd_seq_out[seq_id].resize(n_embd);
+                        ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+                    }
+                } break;
+            case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                {
+                    GGML_ASSERT(false && "unknown pooling type");
+                } break;
+        }
         ggml_backend_synchronize(backend_embd);
     }
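Note the two layouts behind the async copies: with pooling disabled, result_embd has one row per token, so token i is copied from float offset n_embd*i into the matching row of lctx.embd; with CLS or MEAN pooling it has one row per sequence, so the row at offset n_embd*seq_id is copied into that sequence's map entry. A sketch of the offset arithmetic (the fourth argument of ggml_backend_tensor_get_async above is a byte offset):

    // byte offset of row r in a row-major [rows][n_embd] float tensor
    static size_t embd_row_offset(int64_t r, int64_t n_embd) {
        return (size_t)(r * n_embd) * sizeof(float);
    }

The map is cleared before each extraction, so llama_get_embeddings_seq only ever reflects the most recent decode.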
@@ -12273,7 +12311,7 @@ struct llama_context * llama_new_context_with_model(
             ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
 
             if (params.embeddings) {
-                ctx->embeddings.reserve(hparams.n_embd*cparams.n_batch);
+                ctx->embd.reserve(hparams.n_embd*cparams.n_batch);
             }
 
             // graph inputs
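Only the token-embedding buffer gets a worst-case reservation here (n_embd floats per batch token); the per-sequence vectors in embd_seq are allocated on demand during decode, since the number of distinct sequence ids is not known up front.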
@@ -12708,7 +12746,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     // assume worst case for logits although only currently set ones are serialized
     const size_t s_logits = ctx->logits.capacity() * sizeof(float);
     const size_t s_embedding_size = sizeof(size_t);
-    const size_t s_embedding = ctx->embeddings.capacity() * sizeof(float);
+    const size_t s_embedding = ctx->embd.capacity() * sizeof(float);
     const size_t s_kv_buf_size = sizeof(size_t);
     const size_t s_kv_head = sizeof(uint32_t);
     const size_t s_kv_size = sizeof(uint32_t);
@@ -12817,12 +12855,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
 
     // copy embeddings
     {
-        const size_t embeddings_size = ctx->embeddings.size();
+        const size_t embeddings_size = ctx->embd.size();
 
         data_ctx->write(&embeddings_size, sizeof(embeddings_size));
 
         if (embeddings_size) {
-            data_ctx->write(ctx->embeddings.data(), embeddings_size * sizeof(float));
+            data_ctx->write(ctx->embd.data(), embeddings_size * sizeof(float));
         }
     }

@@ -12930,12 +12968,12 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
 
     memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
 
-    GGML_ASSERT(ctx->embeddings.capacity() == embeddings_size);
+    GGML_ASSERT(ctx->embd.capacity() == embeddings_size);
 
     if (embeddings_size) {
-        ctx->embeddings.resize(embeddings_size);
+        ctx->embd.resize(embeddings_size);
 
-        memcpy(ctx->embeddings.data(), inp, embeddings_size * sizeof(float));
+        memcpy(ctx->embd.data(), inp, embeddings_size * sizeof(float));
         inp += embeddings_size * sizeof(float);
     }
 }
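Note that only the token-embedding buffer ctx->embd takes part in state save/restore; the pooled embd_seq map is transient and is simply repopulated by the next llama_decode call.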
@@ -13186,11 +13224,20 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
 }
 
 float * llama_get_embeddings(struct llama_context * ctx) {
-    return ctx->embeddings.data();
+    return ctx->embd.data();
 }
 
 float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
-    return ctx->embeddings.data() + i*ctx->model.hparams.n_embd;
+    return ctx->embd.data() + i*ctx->model.hparams.n_embd;
+}
+
+float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
+    auto it = ctx->embd_seq.find(seq_id);
+    if (it == ctx->embd_seq.end()) {
+        return nullptr;
+    }
+
+    return it->second.data();
 }
 
 const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
10 changes: 8 additions & 2 deletions llama.h
@@ -655,14 +655,20 @@ extern "C" {
     // llama_get_logits(ctx) + i*n_vocab
     LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
 
-    // Get the embeddings for the input
-    // shape: [n_embd] (1-dimensional)
+    // Get all output token embeddings
+    // shape: [n_tokens*n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 
     // Get the embeddings for the ith token
     // llama_get_embeddings(ctx) + i*n_embd
     // shape: [n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
 
+    // Get the embeddings for a sequence id
+    // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
+    // shape: [n_embd] (1-dimensional)
+    LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
+
     //
     // Vocab
     //
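Taken together, the header now draws a clean line: llama_get_embeddings and llama_get_embeddings_ith expose token embeddings (pooling type NONE), while llama_get_embeddings_seq returns the pooled vector for one sequence. A minimal end-to-end sketch (assumes an embedding model whose pooling type is not NONE; error handling trimmed, and embed_batch is an illustrative helper, not part of the API):

    // decode a prepared batch, then fetch the pooled embedding for sequence 0
    static const float * embed_batch(struct llama_context * ctx, struct llama_batch batch) {
        if (llama_decode(ctx, batch) != 0) {
            return NULL; // decode failed
        }
        // NULL when the context uses LLAMA_POOLING_TYPE_NONE
        return llama_get_embeddings_seq(ctx, 0);
    }

The returned pointer aliases memory owned by the context and stays valid only until the next decode.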