From 79e4eede2374870e915af4fe03a291a702879e52 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 4 Mar 2024 19:14:22 +0200
Subject: [PATCH] llama : distinguish token vs sequence embeddings

ggml-ci
---
 examples/embedding/embedding.cpp |  15 +++-
 examples/server-embd.py          |   2 +-
 examples/server/server.cpp       |  16 +++-
 llama.cpp                        | 141 ++++++++++++++++++++-----------
 llama.h                          |  10 ++-
 5 files changed, 128 insertions(+), 56 deletions(-)

diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index bc5ae15a96856..ff5883da6ba27 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -23,7 +23,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
     }
 }
 
-static void normalize(float * vec, float * out, int n) {
+static void normalize(const float * vec, float * out, int n) {
     float norm = 0;
     for (int i = 0; i < n; i++) {
         norm += vec[i] * vec[i];
@@ -50,9 +50,18 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
             continue;
         }
 
-        float * emb = llama_get_embeddings_ith(ctx, i);
+        // try to get sequence embeddings - supported only when pooling_type is not NONE
+        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+        if (embd == NULL) {
+            embd = llama_get_embeddings_ith(ctx, i);
+            if (embd == NULL) {
+                fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
+                continue;
+            }
+        }
+
         float * out = output + batch.seq_id[i][0] * n_embd;
-        normalize(emb, out, n_embd);
+        normalize(embd, out, n_embd);
     }
 }
 
diff --git a/examples/server-embd.py b/examples/server-embd.py
index 7ed7a17d10405..c5c4ea87b09fc 100644
--- a/examples/server-embd.py
+++ b/examples/server-embd.py
@@ -13,7 +13,7 @@ async def main():
     model_url = "http://127.0.0.1:6900"
     responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
         url= f"{model_url}/embedding",
-        json= {"content": str(0)*32}
+        json= {"content": str(i)*1024}
     ) for i in range(n)])
 
     for response in responses:
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 56eeee09304b5..9a09de68dbfff 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1235,12 +1235,22 @@ struct llama_server_context
                     continue;
                 }
 
-                const float * data = llama_get_embeddings_ith(ctx, i);
-                std::vector<float> embedding(data, data + n_embd);
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                    if (embd == NULL) {
+                        LOG_ERROR("failed to get embeddings for token", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}});
+                        res.result_json = json
+                        {
+                            {"embedding", std::vector<float>(n_embd, 0.0f)},
+                        };
+                        continue;
+                    }
+                }
 
                 res.result_json = json
                 {
-                    {"embedding", embedding },
+                    {"embedding", std::vector<float>(embd, embd + n_embd)},
                 };
             }
         }
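Note on the example code above: both embedding.cpp and server.cpp end up L2-normalizing whichever vector they retrieve (pooled or per-token), so downstream similarity scoring reduces to a plain dot product. A minimal sketch of that property, not part of the patch (`dot` is an illustrative helper):

// For unit-length vectors a and b, cosine(a, b) equals dot(a, b),
// since dot(a/|a|, b/|b|) == (a . b) / (|a| |b|).
static float dot(const float * a, const float * b, int n) {
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        sum += a[i] * b[i];
    }
    return sum;
}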
diff --git a/llama.cpp b/llama.cpp
index 6245af22149e5..e966faeeda90f 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1983,7 +1983,12 @@ struct llama_context {
     bool logits_all = false;
 
     // embeddings output (2-dimensional array: [n_tokens][n_embd])
-    std::vector<float> embeddings;
+    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
+    std::vector<float> embd;
+
+    // sequence embeddings output (map of [n_embd] vectors)
+    // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
+    std::map<llama_seq_id, std::vector<float>> embd_seq;
 
     // memory buffers used to evaluate the model
     std::vector<uint8_t> buf_compute_meta;
@@ -6243,12 +6248,23 @@ struct llm_build_context {
         cur = inpL;
 
         // pooling layer
-        if (pooling_type == LLAMA_POOLING_TYPE_MEAN) {
-            cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
-        } else if (pooling_type == LLAMA_POOLING_TYPE_CLS) {
-            cur = ggml_get_rows(ctx0, cur, inp_cls);
-        } else {
-            GGML_ASSERT(pooling_type == LLAMA_POOLING_TYPE_NONE && "Invalid pooling type");
+        switch (pooling_type) {
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    // nop
+                } break;
+            case LLAMA_POOLING_TYPE_MEAN:
+                {
+                    cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
+                } break;
+            case LLAMA_POOLING_TYPE_CLS:
+                {
+                    cur = ggml_get_rows(ctx0, cur, inp_cls);
+                } break;
+            case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                {
+                    GGML_ASSERT(false && "Invalid pooling type");
+                } break;
         }
 
         cb(cur, "result_embd", -1);
@@ -8259,17 +8275,23 @@ static int llama_decode_internal(
     struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
     struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
 
-    if (strcmp(res->name, "result_output") == 0) {
-        // the embeddings could be the second to last tensor, or the third to last tensor
-        if (strcmp(embd->name, "result_norm") != 0) {
-            embd = gf->nodes[gf->n_nodes - 3];
-            GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
-        }
-    } else if (strcmp(res->name, "result_embd") == 0) {
-        embd = res;
-        res  = nullptr;
+    if (!hparams.causal_attn) {
+        res = nullptr; // do not extract logits for embedding models such as BERT
+
+        // token or sequence embeddings
+        embd = gf->nodes[gf->n_nodes - 1];
+
+        GGML_ASSERT(strcmp(embd->name, "result_embd") == 0);
     } else {
-        GGML_ASSERT(false);
+        if (strcmp(res->name, "result_output") == 0) {
+            // the token embeddings could be the second to last tensor, or the third to last tensor
+            if (strcmp(embd->name, "result_norm") != 0) {
+                embd = gf->nodes[gf->n_nodes - 3];
+                GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
+            }
+        } else {
+            GGML_ASSERT(false && "missing result_output tensor");
+        }
     }
 
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
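Note on the MEAN case above: inp_mean is a precomputed weight matrix in which row s holds 1/len(s) at the positions of sequence s's tokens, so a single ggml_mul_mat averages each sequence's token embeddings. A plain C++ sketch of the intended semantics (illustrative shapes only; the real graph operates on ggml tensors):

#include <vector>

// emb      : n_tokens x n_embd token embeddings
// inp_mean : n_seqs   x n_tokens, inp_mean[s][t] = 1/len(s) if token t belongs
//            to sequence s, else 0
// returns  : n_seqs   x n_embd pooled sequence embeddings
static std::vector<std::vector<float>> mean_pool(
        const std::vector<std::vector<float>> & emb,
        const std::vector<std::vector<float>> & inp_mean) {
    const size_t n_seqs   = inp_mean.size();
    const size_t n_tokens = emb.size();
    const size_t n_embd   = emb[0].size();

    std::vector<std::vector<float>> out(n_seqs, std::vector<float>(n_embd, 0.0f));
    for (size_t s = 0; s < n_seqs; s++) {
        for (size_t t = 0; t < n_tokens; t++) {
            for (size_t e = 0; e < n_embd; e++) {
                // weighted sum == mean, because row s of inp_mean sums to 1
                out[s][e] += inp_mean[s][t] * emb[t][e];
            }
        }
    }
    return out;
}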
@@ -8368,30 +8390,46 @@ static int llama_decode_internal(
 
     // extract embeddings
     if (cparams.embeddings && embd) {
-        auto & embeddings_out = lctx.embeddings;
-
         ggml_backend_t backend_embd = ggml_backend_sched_get_node_backend(lctx.sched, embd);
         GGML_ASSERT(backend_embd != nullptr);
 
-        if (batch.logits) {
-            embeddings_out.resize(n_embd * n_tokens);
-            for (uint32_t i = 0; i < n_tokens; i++) {
-                if (batch.logits[i] == 0) {
-                    continue;
-                }
-                switch (hparams.pooling_type) {
-                    case LLAMA_POOLING_TYPE_CLS:
-                        ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float));
-                        break;
-                    case LLAMA_POOLING_TYPE_MEAN:
-                    case LLAMA_POOLING_TYPE_NONE:
-                        ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
-                        break;
-                    default:
-                        GGML_ASSERT(false && "unknown pooling type");
-                        break;
-                }
-            }
+        switch (cparams.pooling_type) {
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    // extract token embeddings
+                    auto & embd_out = lctx.embd;
+
+                    if (batch.logits) {
+                        embd_out.resize(n_embd * n_tokens);
+                        for (uint32_t i = 0; i < n_tokens; i++) {
+                            if (batch.logits[i] == 0) {
+                                continue;
+                            }
+
+                            ggml_backend_tensor_get_async(backend_embd, embd, embd_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
+                        }
+                    }
+                } break;
+            case LLAMA_POOLING_TYPE_CLS:
+            case LLAMA_POOLING_TYPE_MEAN:
+                {
+                    // extract sequence embeddings
+                    auto & embd_seq_out = lctx.embd_seq;
+                    embd_seq_out.clear();
+
+                    for (uint32_t i = 0; i < n_tokens; i++) {
+                        const llama_seq_id seq_id = batch.seq_id[i][0];
+                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                            continue;
+                        }
+                        embd_seq_out[seq_id].resize(n_embd);
+                        ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+                    }
+                } break;
+            case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                {
+                    GGML_ASSERT(false && "unknown pooling type");
+                } break;
         }
 
         ggml_backend_synchronize(backend_embd);
@@ -12273,7 +12311,7 @@ struct llama_context * llama_new_context_with_model(
             ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
 
             if (params.embeddings) {
-                ctx->embeddings.reserve(hparams.n_embd*cparams.n_batch);
+                ctx->embd.reserve(hparams.n_embd*cparams.n_batch);
             }
 
             // graph inputs
@@ -12708,7 +12746,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     // assume worst case for logits although only currently set ones are serialized
     const size_t s_logits          = ctx->logits.capacity() * sizeof(float);
     const size_t s_embedding_size  = sizeof(size_t);
-    const size_t s_embedding       = ctx->embeddings.capacity() * sizeof(float);
+    const size_t s_embedding       = ctx->embd.capacity() * sizeof(float);
     const size_t s_kv_buf_size     = sizeof(size_t);
     const size_t s_kv_head         = sizeof(uint32_t);
     const size_t s_kv_size         = sizeof(uint32_t);
@@ -12817,12 +12855,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
 
     // copy embeddings
     {
-        const size_t embeddings_size = ctx->embeddings.size();
+        const size_t embeddings_size = ctx->embd.size();
 
         data_ctx->write(&embeddings_size, sizeof(embeddings_size));
 
         if (embeddings_size) {
-            data_ctx->write(ctx->embeddings.data(), embeddings_size * sizeof(float));
+            data_ctx->write(ctx->embd.data(), embeddings_size * sizeof(float));
         }
     }
 
@@ -12930,12 +12968,12 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
 
-        GGML_ASSERT(ctx->embeddings.capacity() == embeddings_size);
+        GGML_ASSERT(ctx->embd.capacity() == embeddings_size);
 
         if (embeddings_size) {
-            ctx->embeddings.resize(embeddings_size);
+            ctx->embd.resize(embeddings_size);
 
-            memcpy(ctx->embeddings.data(), inp, embeddings_size * sizeof(float));
+            memcpy(ctx->embd.data(), inp, embeddings_size * sizeof(float));
             inp += embeddings_size * sizeof(float);
         }
     }
@@ -13186,11 +13224,20 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
 }
 
 float * llama_get_embeddings(struct llama_context * ctx) {
-    return ctx->embeddings.data();
+    return ctx->embd.data();
 }
 
 float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
-    return ctx->embeddings.data() + i*ctx->model.hparams.n_embd;
+    return ctx->embd.data() + i*ctx->model.hparams.n_embd;
+}
+
+float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
+    auto it = ctx->embd_seq.find(seq_id);
+    if (it == ctx->embd_seq.end()) {
+        return nullptr;
+    }
+
+    return it->second.data();
 }
 
 const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
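Note on the CLS/MEAN branch above: exactly one pooled vector is kept per sequence, keyed by the first seq_id of each token, and the backend copy reads row seq_id of the pooled tensor. A host-side sketch of the same bookkeeping (a simplification of the async backend copy; assumes the pooled rows already sit in a contiguous float buffer named pooled):

#include <cstdint>
#include <cstring>
#include <map>
#include <vector>

using llama_seq_id = int32_t; // matches the typedef in llama.h

static void extract_seq_embd(
        const float * pooled,         // n_seqs rows of n_embd floats each
        const llama_seq_id * seq_ids, // first seq id of each token in the batch
        int n_tokens, int n_embd,
        std::map<llama_seq_id, std::vector<float>> & embd_seq_out) {
    embd_seq_out.clear();
    for (int i = 0; i < n_tokens; i++) {
        const llama_seq_id seq_id = seq_ids[i];
        if (embd_seq_out.count(seq_id)) {
            continue; // one pooled vector per sequence
        }
        embd_seq_out[seq_id].resize(n_embd);
        // row `seq_id` of the pooled tensor, i.e. an offset of n_embd*seq_id floats
        memcpy(embd_seq_out[seq_id].data(), pooled + (size_t) n_embd*seq_id, n_embd*sizeof(float));
    }
}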
diff --git a/llama.h b/llama.h
index fed91493111ad..3dc162b078d30 100644
--- a/llama.h
+++ b/llama.h
@@ -655,14 +655,20 @@ extern "C" {
     // llama_get_logits(ctx) + i*n_vocab
     LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
 
-    // Get the embeddings for the input
-    // shape: [n_embd] (1-dimensional)
+    // Get all output token embeddings
+    // shape: [n_tokens*n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 
     // Get the embeddings for the ith token
     // llama_get_embeddings(ctx) + i*n_embd
+    // shape: [n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
 
+    // Get the embeddings for a sequence id
+    // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
+    // shape: [n_embd] (1-dimensional)
+    LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
+
     //
     // Vocab
     //
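A hedged usage sketch of the resulting API (not part of the patch; model, context, and batch setup as in examples/embedding are assumed, and n_seqs is a hypothetical count of decoded sequences):

#include <cstdio>

#include "llama.h"

// After llama_decode(), print the first component of each pooled sequence embedding.
static void print_pooled(llama_context * ctx, int n_seqs) {
    for (llama_seq_id s = 0; s < n_seqs; s++) {
        // non-NULL only when the context uses a pooling_type other than
        // LLAMA_POOLING_TYPE_NONE
        const float * embd = llama_get_embeddings_seq(ctx, s);
        if (embd == NULL) {
            // fall back to per-token embeddings via llama_get_embeddings_ith(),
            // as the updated examples in this patch do
            continue;
        }
        printf("seq %d: embd[0] = %f\n", s, embd[0]);
    }
}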