Add embedding mode with arg flag. Currently working (ggerganov#282)
* working but ugly

* add arg flag, not working on embedding mode

* typo

* Working! Thanks to @nullhook

* make params argument instead of hardcoded boolean. remove useless time check

* start doing the instructions but not finished. This probably doesnt compile

* Embeddings extraction support

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
StrikingLoo and ggerganov authored Mar 24, 2023
1 parent b6b268d commit 8d4a855
Showing 5 changed files with 82 additions and 10 deletions.
56 changes: 46 additions & 10 deletions llama.cpp
@@ -102,6 +102,9 @@ struct llama_context {
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
bool logits_all = false;

// input embedding (1-dimensional array: [n_embd])
std::vector<float> embedding;
};

struct llama_context_params llama_context_default_params() {
@@ -112,6 +115,7 @@ struct llama_context_params llama_context_default_params() {
/*.f16_kv =*/ false,
/*.logits_all =*/ false,
/*.vocab_only =*/ false,
/*.embedding =*/ false,
};

return result;
@@ -592,8 +596,6 @@ static bool llama_model_load(
fin.close();
}

lctx.logits.reserve(lctx.model.hparams.n_ctx);

lctx.t_load_us = ggml_time_us() - t_start_us;

return true;
@@ -791,6 +793,9 @@ static bool llama_eval_internal(
inpL = cur;
}

// used at the end to optionally extract the embeddings
struct ggml_tensor * embeddings = NULL;

// norm
{
inpL = ggml_rms_norm(ctx0, inpL);
@@ -799,6 +804,8 @@
inpL = ggml_mul(ctx0,
ggml_repeat(ctx0, model.norm, inpL),
inpL);

embeddings = inpL;
}

// lm_head
@@ -821,15 +828,26 @@
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);

// extract logits
{
auto & logits_out = lctx.logits;

if (lctx.logits_all) {
logits_out.resize(n_vocab * N);
memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
} else {
// return result for just the last token
logits_out.resize(n_vocab);
memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
}
}

// extract embeddings
if (lctx.embedding.size()) {
auto & embedding_out = lctx.embedding;

embedding_out.resize(n_embd);
memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
}

if (mem_per_token == 0) {
@@ -1416,6 +1434,20 @@ struct llama_context * llama_init_from_file(
return nullptr;
}

// reserve memory for context buffers
{
const auto & hparams = ctx->model.hparams;
if (params.logits_all) {
ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
} else {
ctx->logits.reserve(hparams.n_ctx);
}

if (params.embedding){
ctx->embedding.reserve(hparams.n_embd);
}
}

return ctx;
}

@@ -1484,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) {
return ctx->logits.data();
}

float * llama_get_embeddings(struct llama_context * ctx) {
return ctx->embedding.data();
}

const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
if (token >= llama_n_vocab(ctx)) {
return nullptr;
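The embedding path in llama_eval_internal() copies the last token's row of the post-RMS-norm activations (n_embd floats) into lctx.embedding, which llama_get_embeddings() then exposes. A common follow-up step, shown here only as a hedged sketch and not part of this commit, is to L2-normalize that vector so a plain dot product between two prompts' embeddings gives their cosine similarity:

#include <cmath>
#include <vector>

// Hypothetical helper (not in the commit): L2-normalize the n_embd floats
// returned by llama_get_embeddings() so dot products become cosine similarity.
static std::vector<float> normalize_embedding(const float * emb, int n_embd) {
    double norm = 0.0;
    for (int i = 0; i < n_embd; i++) {
        norm += (double) emb[i] * emb[i];
    }
    norm = std::sqrt(norm);

    std::vector<float> out(n_embd);
    for (int i = 0; i < n_embd; i++) {
        out[i] = norm > 0.0 ? (float) (emb[i] / norm) : 0.0f;
    }
    return out;
}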
5 changes: 5 additions & 0 deletions llama.h
@@ -53,6 +53,7 @@ extern "C" {
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool vocab_only; // only load the vocabulary, no weights
bool embedding; // embedding mode only
};

LLAMA_API struct llama_context_params llama_context_default_params();
@@ -108,6 +109,10 @@ extern "C" {
// Cols: n_vocab
LLAMA_API float * llama_get_logits(struct llama_context * ctx);

// Get the embeddings for the input
// shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

// Token Id -> String. Uses the vocabulary in the provided context
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);

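Together with the new llama_context_params.embedding flag, the getter suggests a caller flow like the sketch below. This is only an illustration against the C API as declared in this diff; the model path and thread count are placeholders, and llama_tokenize, llama_n_embd and llama_free are assumed to be the existing helpers in llama.h:

#include <cstdio>
#include <vector>
#include "llama.h"

int main() {
    // enable the new embedding mode
    llama_context_params lparams = llama_context_default_params();
    lparams.embedding = true;

    // placeholder model path
    llama_context * ctx = llama_init_from_file("models/7B/ggml-model.bin", lparams);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // tokenize the prompt (add_bos = true)
    std::vector<llama_token> tokens(64);
    const int n_tokens = llama_tokenize(ctx, "Hello world", tokens.data(), (int) tokens.size(), true);
    if (n_tokens < 0) {
        fprintf(stderr, "failed to tokenize\n");
        return 1;
    }

    // one eval over the whole prompt fills the context's embedding buffer
    if (llama_eval(ctx, tokens.data(), n_tokens, 0, 4) != 0) { // 4 threads, placeholder
        fprintf(stderr, "failed to eval\n");
        return 1;
    }

    const float * emb = llama_get_embeddings(ctx);
    const int n_embd  = llama_n_embd(ctx);
    printf("first of %d embedding values: %f\n", n_embd, emb[0]);

    llama_free(ctx);
    return 0;
}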
23 changes: 23 additions & 0 deletions main.cpp
@@ -199,6 +199,7 @@ int main(int argc, char ** argv) {
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;

ctx = llama_init_from_file(params.model.c_str(), lparams);

@@ -292,6 +293,7 @@ int main(int argc, char ** argv) {

std::vector<llama_token> embd;


int last_n_size = params.repeat_last_n;
std::vector<llama_token> last_n_tokens(last_n_size);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
@@ -324,6 +326,27 @@
// the first thing we will do is to output the prompt, so set color accordingly
set_console_state(CONSOLE_STATE_PROMPT);

if (params.embedding){
embd = embd_inp;

if (embd.size() > 0) {
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
}

const auto embeddings = llama_get_embeddings(ctx);

// TODO: print / use the embeddings

if (params.use_color) {
printf(ANSI_COLOR_RESET);
}

return 0;
}

while (remaining_tokens > 0 || params.interactive) {
// predict
if (embd.size() > 0) {
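The TODO above is left open in the commit; a hypothetical way to use the result, assuming the existing llama_n_embd accessor, is simply to dump the vector:

// hypothetical fill-in for the TODO: print the sentence embedding
const int n_embd = llama_n_embd(ctx);
for (int i = 0; i < n_embd; i++) {
    printf("%f ", embeddings[i]);
}
printf("\n");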
4 changes: 4 additions & 0 deletions utils.cpp
@@ -117,6 +117,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.model = argv[i];
} else if (arg == "-i" || arg == "--interactive") {
params.interactive = true;
} else if (arg == "--embedding") {
params.embedding = true;
} else if (arg == "--interactive-start") {
params.interactive = true;
} else if (arg == "--interactive-first") {
params.interactive_start = true;
} else if (arg == "-ins" || arg == "--instruct") {
4 changes: 4 additions & 0 deletions utils.h
@@ -32,13 +32,17 @@ struct gpt_params {
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
std::string prompt = "";


std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted

bool memory_f16 = false; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided
bool use_color = false; // use color to distinguish generations and inputs
bool interactive = false; // interactive mode

bool embedding = false; // get only sentence embedding
bool interactive_start = false; // wait for user input immediately

bool instruct = false; // instruction mode (used for Alpaca models)
bool ignore_eos = false; // do not stop generating after eos
bool perplexity = false; // compute perplexity over the prompt
