add retrieval example #6193

Merged
merged 21 commits on Mar 25, 2024
Changes from 2 commits
1 change: 1 addition & 0 deletions .gitignore
@@ -72,6 +72,7 @@ models-mnt
/batched-bench
/export-lora
/finetune
/retrieval
/speculative
/parallel
/train-text-from-scratch
6 changes: 5 additions & 1 deletion Makefile
@@ -2,7 +2,7 @@
BUILD_TARGETS = \
main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o
retrieval speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o

# Binaries only useful for tests
TEST_TARGETS = \
@@ -794,6 +794,10 @@ export-lora: examples/export-lora/export-lora.cpp ggml.o common/common.h $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

retrieval: examples/retrieval/retrieval.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

speculative: examples/speculative/speculative.cpp ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
42 changes: 42 additions & 0 deletions common/common.cpp
@@ -276,6 +276,43 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
}
return true;
}
if (arg == "--context-files") {
if (++i >= argc) {
invalid_param = true;
return true;
}
while(true) {
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
invalid_param = true;
break;
}
// store the external file name in params
params.context_files.push_back(argv[i]);
if (i + 1 >= argc || argv[i + 1][0] == '-') {
break;
}
i++;
}
return true;
}
if (arg == "--chunk-size") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.chunk_size = std::stoi(argv[i]);
return true;
}
if (arg == "--chunk-separator") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.chunk_separator = argv[i];
return true;
}
if (arg == "-n" || arg == "--n-predict") {
if (++i >= argc) {
invalid_param = true;
@@ -1282,6 +1319,11 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" prompt file to start generation.\n");
printf(" -bf FNAME, --binary-file FNAME\n");
printf(" binary file containing multiple choice tasks.\n");
printf(" --context-files FNAME1 FNAME2...\n");
printf(" files containing context to embed.\n");
printf(" --chunk-size N minimum length of embedded text chunk (default:%d)\n", params.chunk_size);
printf(" --chunk-separator STRING\n");
printf(" string to separate chunks (default: newline)\n");
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
printf(" -b N, --batch-size N logical maximum batch size (default: %d)\n", params.n_batch);
3 changes: 3 additions & 0 deletions common/common.h
@@ -79,6 +79,9 @@ struct gpt_params {
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
float defrag_thold = -1.0f; // KV cache defragmentation threshold
std::vector<std::string> context_files = {}; // context files to embed
int32_t chunk_size = 64; // chunk size for context embedding
std::string chunk_separator = "\n"; // chunk separator for context embedding

ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -34,6 +34,7 @@ else()
add_subdirectory(perplexity)
add_subdirectory(quantize)
add_subdirectory(quantize-stats)
add_subdirectory(retrieval)
add_subdirectory(save-load-state)
add_subdirectory(simple)
add_subdirectory(passkey)
5 changes: 5 additions & 0 deletions examples/retrieval/CMakeLists.txt
@@ -0,0 +1,5 @@
set(TARGET retrieval)
add_executable(${TARGET} retrieval.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
6 changes: 6 additions & 0 deletions examples/retrieval/README.md
@@ -0,0 +1,6 @@
# llama.cpp/examples/retrieval

Demonstration of a simple retrieval technique based on cosine similarity

More info:
https://github.com/ggerganov/llama.cpp/pull/6193
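
For orientation, a minimal invocation might look like this (a sketch only: the model path and input file names are placeholders, -m is the common model flag, and --context-files, --chunk-size, and --chunk-separator are the options added in this PR). Chunks and the query are embedded, normalized, and ranked by cosine similarity, and the top-k most similar chunks are printed for each query typed at the prompt.

./retrieval -m path/to/an-embedding-model.gguf --context-files somefile.txt anotherfile.txt --chunk-size 100 --chunk-separator .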
278 changes: 278 additions & 0 deletions examples/retrieval/retrieval.cpp
@@ -0,0 +1,278 @@
#include "common.h"
#include "llama.h"

#include <algorithm>
#include <fstream>

struct chunk {
// filename
std::string filename;
// original file position
int64_t filepos;
// original text data
std::string textdata = "";
// tokenized text data
std::vector<std::int32_t> tokens;
// embedding
std::vector<float> embedding;
// cosine similarity to the current query
float similarity;
};

// split the file into chunks of at least chunk_size bytes,
// cutting only at occurrences of chunk_separator
static std::vector<chunk> chunk_file(const std::string filename, int chunk_size, std::string chunk_separator) {
std::vector<chunk> chunks;
std::ifstream f(filename.c_str());

if (!f.is_open()) {
fprintf(stderr, "Error: could not open file %s\n", filename.c_str());
return chunks;
}

chunk current_chunk;
char buffer[chunk_size];
int64_t filepos = 0;
std::string current = "";
while (f.read(buffer, chunk_size)) {
current += std::string(buffer, f.gcount());
size_t pos;
while ((pos = current.find(chunk_separator)) != std::string::npos) {
current_chunk.textdata += current.substr(0, pos + chunk_separator.size());
if ((int) current_chunk.textdata.size() > chunk_size) {
// save chunk
current_chunk.filepos = filepos;
current_chunk.filename = filename;
chunks.push_back(current_chunk);
// update filepos
filepos += (int) current_chunk.textdata.size();
// reset current_chunk
current_chunk = chunk();
}
current = current.substr(pos + chunk_separator.size());
}

}
// add leftover data to last chunk
if (current_chunk.textdata.size() > 0) {
if (chunks.empty()) {
current_chunk.filepos = filepos;
current_chunk.filename = filename;
chunks.push_back(current_chunk);
} else {
chunks.back().textdata += current_chunk.textdata;
}
}
f.close();
return chunks;
}

static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
for (size_t i = 0; i < tokens.size(); i++) {
llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
}
}

static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);

// run model
fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
if (llama_decode(ctx, batch) < 0) {
fprintf(stderr, "%s : failed to decode\n", __func__);
}

for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
continue;
}

// try to get sequence embeddings - supported only when pooling_type is not NONE
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
if (embd == NULL) {
fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
continue;
}
}

float * out = output + batch.seq_id[i][0] * n_embd;
llama_embd_normalize(embd, out, n_embd);
}
}

int main(int argc, char ** argv) {
gpt_params params;

if (!gpt_params_parse(argc, argv, params)) {
return 1;
}

if (params.chunk_size <= 0) {
fprintf(stderr, "chunk_size must be positive\n");
return 1;
}
if (params.context_files.empty()) {
fprintf(stderr, "context_files must be specified\n");
return 1;
}
params.embedding = true;

print_build_info();

if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}

printf("processing files:\n");
for (auto & context_file : params.context_files) {
printf("%s\n", context_file.c_str());
}

std::vector<chunk> chunks;
for (auto & context_file : params.context_files) {
std::vector<chunk> file_chunk = chunk_file(context_file, params.chunk_size, params.chunk_separator);
chunks.insert(chunks.end(), file_chunk.begin(), file_chunk.end());
}
printf("Number of chunks: %ld\n", chunks.size());

llama_backend_init();
llama_numa_init(params.numa);

llama_model * model;
llama_context * ctx;

// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;
}

const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);

if (n_ctx > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, n_ctx);
}

// print system information
{
fprintf(stderr, "\n");
fprintf(stderr, "%s\n", get_system_info(params).c_str());
}

// max batch size
const uint64_t n_batch = params.n_batch;
GGML_ASSERT(params.n_batch >= params.n_ctx);

// tokenize the prompts and trim
for (auto & chunk : chunks) {
auto inp = ::llama_tokenize(ctx, chunk.textdata, true, false);
if (inp.size() > n_batch) {
inp.resize(n_batch);
Collaborator:
I'm not quite sure about this line. Will it drop tokens if inp.size() > n_batch? If that's the case, it means we drop information from the embedding input, which will make the output embedding inaccurate.

Collaborator Author:
I guess we'll have to restrict the chunk size to fit the n_batch parameter somehow. That raises the same issue as a configurable maximum chunk length: how should we handle cases where no separator appears within the maximum length?

Collaborator (@ngxson, Mar 22, 2024):
Or maybe the reverse: we can require the user to provide a sufficient n_batch (and n_ubatch) before running the app. AFAIK that's needed because with embeddings we use non-causal models, which require the prompt to be processed in one batch.

My plan: after tokenizing all chunks, check whether any chunk has tokens.size() > n_batch; if so, raise an error and exit the program.

Collaborator Author:
> after tokenizing all chunks, check whether any chunk has tokens.size() > n_batch; if so, raise an error and exit the program

This seems like the best option for now 👍

Owner:
Doesn't it make more sense to determine params.n_batch based on the largest chunk after tokenization? It should not be a user-provided parameter in this example - just set it to the largest chunk size

Collaborator:
Yeah, that could be a solution. I'm just not sure whether that requires re-creating the llama_context, because the prior llama_init_from_gpt_params call has already used params.n_batch and params.n_ubatch.

}
// add eos if not present
if (inp.empty() || inp.back() != llama_token_eos(model)) {
inp.push_back(llama_token_eos(model));
}
chunk.tokens = inp;
}
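
As an illustration of the check discussed in the review thread above (a sketch only, not part of this diff), the validation could run right after the tokenize-and-trim loop, reusing the chunks vector and the n_batch value defined earlier:

    // reject any chunk that cannot be processed in a single batch,
    // since non-causal embedding models need the whole prompt in one batch
    for (const auto & chunk : chunks) {
        if (chunk.tokens.size() > n_batch) {
            fprintf(stderr, "error: chunk of %zu tokens exceeds batch size %llu; "
                            "increase --batch-size or reduce --chunk-size\n",
                    chunk.tokens.size(), (unsigned long long) n_batch);
            return 1;
        }
    }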

// tokenization stats
if (params.verbose_prompt) {
for (int i = 0; i < (int) chunks.size(); i++) {
fprintf(stderr, "%s: prompt %d: '%s'\n", __func__, i, chunks[i].textdata.c_str());
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, chunks[i].tokens.size());
for (int j = 0; j < (int) chunks[i].tokens.size(); j++) {
fprintf(stderr, "%6d -> '%s'\n", chunks[i].tokens[j], llama_token_to_piece(ctx, chunks[i].tokens[j]).c_str());
}
fprintf(stderr, "\n\n");
}
}

// initialize batch
const int n_chunks = chunks.size();
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);

// allocate output
const int n_embd = llama_n_embd(model);
std::vector<float> embeddings(n_chunks * n_embd, 0);
float * emb = embeddings.data();

// break into batches
int p = 0; // number of prompts processed already
int s = 0; // number of prompts in current batch
for (int k = 0; k < n_chunks; k++) {
// clamp to n_batch tokens
auto & inp = chunks[k].tokens;

const uint64_t n_toks = inp.size();

// encode if at capacity
if (batch.n_tokens + n_toks > n_batch) {
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);
llama_batch_clear(batch);
p += s;
s = 0;
}

// add to batch
batch_add_seq(batch, inp, s);
s += 1;
}

// final batch
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);

// save embeddings to chunks
for (int i = 0; i < n_chunks; i++) {
chunks[i].embedding = std::vector<float>(emb + i * n_embd, emb + (i + 1) * n_embd);
}

// start loop, receive query and return top k similar chunks based on cosine similarity
std::string query;
while (true) {
printf("Enter query: ");
std::getline(std::cin, query);
if (query == "exit" || query == "quit" || query == "q") {
break;
}
std::vector<int32_t> query_tokens = llama_tokenize(ctx, query, true);

struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
batch_add_seq(query_batch, query_tokens, 0);
float * query_emb = new float[n_embd];
batch_decode(ctx, query_batch, query_emb, 1, n_embd);
std::vector<float> query_embedding(query_emb, query_emb + n_embd);
delete[] query_emb;
llama_batch_clear(query_batch);

for (int i = 0; i < n_chunks; i++) {
float similarity = llama_embd_similarity_cos(chunks[i].embedding.data(), query_embedding.data(), n_embd);
chunks[i].similarity = similarity;
}
std::sort(chunks.begin(), chunks.end(), [](chunk & a, chunk & b) {
return a.similarity > b.similarity;
});
printf("Top %d similar chunks:\n", params.sparams.top_k);
for (int i = 0; i < std::min(params.sparams.top_k, (int) chunks.size()); i++) {
printf("filename: %s\n", chunks[i].filename.c_str());
printf("filepos: %lld\n", chunks[i].filepos);
printf("similarity: %f\n", chunks[i].similarity);
printf("textdata:\n%s\n", chunks[i].textdata.c_str());
printf("--------------------\n");
}
}

// clean up
llama_print_timings(ctx);
llama_free(ctx);
llama_free_model(model);
llama_backend_free();
}