llama : support batched embeddings (#5466)
* batched embedding: pool outputs by sequence id. updated embedding example

* bring back non-causal attention

* embd : minor improvements

* llama : minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
iamlemec and ggerganov authored Feb 13, 2024
1 parent ad014bb commit 03bf161
Showing 6 changed files with 161 additions and 52 deletions.
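The gist of the change: a single llama_batch can now carry tokens from several prompts, each tagged with its own sequence id, and the token outputs are pooled per sequence so that every prompt yields exactly one embedding vector. The sketch below only illustrates that pooling idea in isolation (mean pooling over the tokens of each sequence); the helper name pool_by_seq and the plain-vector types are assumptions made for a self-contained example, not code from this commit.

// Illustrative sketch only: average the per-token embeddings of each
// sequence id to get one vector per prompt. Assumes seq ids are 0..n_seq-1.
#include <cstdio>
#include <vector>

static std::vector<std::vector<float>> pool_by_seq(
        const std::vector<std::vector<float>> & token_embd, // [n_tokens][n_embd]
        const std::vector<int>                & seq_ids,    // [n_tokens]
        int n_seq) {
    const size_t n_embd = token_embd.empty() ? 0 : token_embd[0].size();
    std::vector<std::vector<float>> pooled(n_seq, std::vector<float>(n_embd, 0.0f));
    std::vector<int> counts(n_seq, 0);

    for (size_t t = 0; t < token_embd.size(); t++) {
        const int s = seq_ids[t];
        counts[s]++;
        for (size_t i = 0; i < n_embd; i++) {
            pooled[s][i] += token_embd[t][i];
        }
    }
    for (int s = 0; s < n_seq; s++) {
        for (size_t i = 0; i < n_embd; i++) {
            pooled[s][i] /= counts[s] > 0 ? counts[s] : 1;
        }
    }
    return pooled;
}

int main() {
    // two tokens belong to sequence 0, one token to sequence 1
    const std::vector<std::vector<float>> token_embd = {{1, 2}, {3, 4}, {5, 6}};
    const std::vector<int> seq_ids = {0, 0, 1};

    const auto pooled = pool_by_seq(token_embd, seq_ids, 2);
    printf("seq 0: %.1f %.1f\n", pooled[0][0], pooled[0][1]); // 2.0 3.0
    printf("seq 1: %.1f %.1f\n", pooled[1][0], pooled[1][1]); // 5.0 6.0
    return 0;
}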
1 change: 1 addition & 0 deletions convert-hf-to-gguf.py
@@ -1648,6 +1648,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
         self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
         self.gguf_writer.add_causal_attention(False)
+        self.gguf_writer.add_pooling_layer(True)
         self.gguf_writer.add_file_type(self.ftype)

     def set_vocab(self):
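With the converter change above, BERT-style GGUF files now carry a boolean flag recording whether the model ends in a pooling layer. A rough way to inspect that flag offline is sketched below using ggml's gguf C API; it assumes the gguf_* declarations from ggml.h and a model whose architecture string is bert, so the key resolves to "bert.pooling_layer" per the "{arch}.pooling_layer" template added in gguf-py further down.

// Sketch, not part of the commit: read the flag written by the converter.
// Assumes the gguf C API declared in ggml.h and a BERT-architecture file.
#include <cstdio>
#include "ggml.h"

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s model.gguf\n", argv[0]);
        return 1;
    }

    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ nullptr };
    struct gguf_context * ctx = gguf_init_from_file(argv[1], params);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to open %s\n", argv[1]);
        return 1;
    }

    const int kid = gguf_find_key(ctx, "bert.pooling_layer");
    if (kid < 0) {
        printf("bert.pooling_layer: not present\n");
    } else {
        printf("bert.pooling_layer: %s\n", gguf_get_val_bool(ctx, kid) ? "true" : "false");
    }

    gguf_free(ctx);
    return 0;
}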
142 changes: 106 additions & 36 deletions examples/embedding/embedding.cpp
@@ -7,6 +7,51 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif

+static std::vector<std::string> split_lines(const std::string & s) {
+    std::string line;
+    std::vector<std::string> lines;
+    std::stringstream ss(s);
+    while (std::getline(ss, line)) {
+        lines.push_back(line);
+    }
+    return lines;
+}
+
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
+    for (size_t i = 0; i < tokens.size(); i++) {
+        llama_batch_add(batch, tokens[i], i, { seq_id }, false);
+    }
+}
+
+static void normalize(float * vec, float * out, int n) {
+    float norm = 0;
+    for (int i = 0; i < n; i++) {
+        norm += vec[i] * vec[i];
+    }
+    norm = sqrt(norm);
+    for (int i = 0; i < n; i++) {
+        out[i] = vec[i] / norm;
+    }
+}
+
+static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
+    // clear previous kv_cache values (irrelevant for embeddings)
+    llama_kv_cache_clear(ctx);
+
+    // run model
+    fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
+    if (llama_decode(ctx, batch) < 0) {
+        fprintf(stderr, "%s : failed to decode\n", __func__);
+    }
+
+    // normalize on copy
+    for (int k = 0; k < n_seq; k++) {
+        float * emb = llama_get_embeddings_ith(ctx, k);
+        float * out = output + k * n_embd;
+        normalize(emb, out, n_embd);
+    }
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;

@@ -55,59 +100,84 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s\n", get_system_info(params).c_str());
     }

-    int n_past = 0;
+    // split the prompt into lines
+    std::vector<std::string> prompts = split_lines(params.prompt);

-    // tokenize the prompt
-    auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+    // max batch size
+    const uint64_t n_batch = params.n_batch;
+    GGML_ASSERT(params.n_batch == params.n_ctx);

-    if (params.verbose_prompt) {
-        fprintf(stderr, "\n");
-        fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-        fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
-        for (int i = 0; i < (int) embd_inp.size(); i++) {
-            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
+    // tokenize the prompts and trim
+    std::vector<std::vector<int32_t>> inputs;
+    for (const auto & prompt : prompts) {
+        auto inp = ::llama_tokenize(ctx, prompt, true);
+        if (inp.size() > n_batch) {
+            inp.resize(n_batch);
         }
-        fprintf(stderr, "\n");
+        inputs.push_back(inp);
     }

-    if (embd_inp.size() > (size_t)n_ctx) {
-        fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
-                __func__, embd_inp.size(), n_ctx);
-        return 1;
-    }
-
-    while (!embd_inp.empty()) {
-        int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
-        if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return 1;
+    // tokenization stats
+    if (params.verbose_prompt) {
+        for (int i = 0; i < (int) inputs.size(); i++) {
+            fprintf(stderr, "%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str());
+            fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size());
+            for (int j = 0; j < (int) inputs[i].size(); j++) {
+                fprintf(stderr, "%6d -> '%s'\n", inputs[i][j], llama_token_to_piece(ctx, inputs[i][j]).c_str());
+            }
+            fprintf(stderr, "\n\n");
         }
-        n_past += n_tokens;
-        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
     }

+    // initialize batch
+    const int n_prompts = prompts.size();
+    struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts);
+
+    // allocate output
     const int n_embd = llama_n_embd(model);
-    auto * embeddings = llama_get_embeddings(ctx);
+    std::vector<float> embeddings(n_prompts * n_embd, 0);
+    float * emb = embeddings.data();
+
+    // break into batches
+    int p = 0; // number of prompts processed already
+    int s = 0; // number of prompts in current batch
+    for (int k = 0; k < n_prompts; k++) {
+        // clamp to n_batch tokens
+        auto & inp = inputs[k];
+        const uint64_t n_toks = inp.size();
+
+        // encode if at capacity
+        if (batch.n_tokens + n_toks > n_batch) {
+            float * out = emb + p * n_embd;
+            batch_decode(ctx, batch, out, s, n_embd);
+            llama_batch_clear(batch);
+            p += s;
+            s = 0;
+        }

-    // l2-normalize embeddings
-    float norm = 0;
-    for (int i = 0; i < n_embd; i++) {
-        norm += embeddings[i] * embeddings[i];
-    }
-    norm = sqrt(norm);
-    for (int i = 0; i < n_embd; i++) {
-        embeddings[i] /= norm;
+        // add to batch
+        batch_add_seq(batch, inp, s);
+        s += 1;
     }

-    for (int i = 0; i < n_embd; i++) {
-        printf("%f ", embeddings[i]);
+    // final batch
+    float * out = emb + p * n_embd;
+    batch_decode(ctx, batch, out, s, n_embd);
+
+    // print first 3 embeddings
+    for (int j = 0; j < std::min(3, n_prompts); j++) {
+        fprintf(stderr, "embedding %d: ", j);
+        for (int i = 0; i < n_embd; i++) {
+            fprintf(stderr, "%f ", emb[j * n_embd + i]);
+        }
+        fprintf(stderr, "\n\n");
     }
-    printf("\n");
+    fprintf(stderr, "\n");

     // clean up
     llama_print_timings(ctx);
     llama_free(ctx);
     llama_free_model(model);

     llama_backend_free();

     return 0;
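Since batch_decode L2-normalizes every pooled embedding as it copies it out, comparing two prompts afterwards reduces to a dot product. The sketch below is a follow-on usage example, not part of the commit; in the example above the call would be cosine_sim(emb + 0 * n_embd, emb + 1 * n_embd, n_embd), and the toy vectors here merely stand in for two pooled prompt embeddings.

// Sketch: with unit-length embeddings, cosine similarity is the dot product.
#include <cstdio>

static float cosine_sim(const float * a, const float * b, int n) {
    float dot = 0.0f;
    for (int i = 0; i < n; i++) {
        dot += a[i] * b[i];
    }
    return dot; // == cosine similarity because ||a|| = ||b|| = 1
}

int main() {
    // two toy unit-length vectors standing in for pooled prompt embeddings
    const float e0[4] = { 0.5f, 0.5f,  0.5f,  0.5f };
    const float e1[4] = { 0.5f, 0.5f, -0.5f, -0.5f };

    printf("similarity = %f\n", cosine_sim(e0, e1, 4)); // 0.000000
    return 0;
}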
1 change: 1 addition & 0 deletions gguf-py/gguf/constants.py
@@ -40,6 +40,7 @@ class LLM:
         TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
         EXPERT_COUNT = "{arch}.expert_count"
         EXPERT_USED_COUNT = "{arch}.expert_used_count"
+        POOLING_LAYER = "{arch}.pooling_layer"

     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -360,6 +360,9 @@ def add_layer_norm_rms_eps(self, value: float) -> None:
     def add_causal_attention(self, value: bool) -> None:
         self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)

+    def add_pooling_layer(self, value: bool) -> None:
+        self.add_bool(Keys.LLM.POOLING_LAYER.format(arch=self.arch), value)
+
     def add_rope_dimension_count(self, count: int) -> None:
         self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)

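For completeness, the same boolean key can also be produced from C/C++ with ggml's gguf writer, mirroring the add_bool call made by the new GGUFWriter.add_pooling_layer method above. This is only a sketch: the bert architecture prefix and the output file name are assumptions, and the converters in this repository write the key through gguf-py rather than this API.

// Sketch, not part of the commit: write "<arch>.pooling_layer" with the
// gguf C API instead of gguf-py. The "bert" prefix and the output file
// name are illustrative assumptions.
#include "ggml.h"

int main() {
    struct gguf_context * ctx = gguf_init_empty();

    gguf_set_val_str (ctx, "general.architecture", "bert");
    gguf_set_val_bool(ctx, "bert.pooling_layer",    true);

    // metadata-only file, no tensor data
    gguf_write_to_file(ctx, "pooling-flag-demo.gguf", /*only_meta =*/ true);

    gguf_free(ctx);
    return 0;
}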