Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add retrieval example #6193

Merged
merged 21 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
define retrieval-only parameters in retrieval.cpp
  • Loading branch information
mscheong01 committed Mar 24, 2024
commit 56b7db971e1851964f026a9977d3fbd2e7f46ca1
44 changes: 1 addition & 43 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
return result;
}

static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int & i, bool & invalid_param) {
bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int & i, bool & invalid_param) {
std::string arg = argv[i];
llama_sampling_params& sparams = params.sparams;

Expand Down Expand Up @@ -276,43 +276,6 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
}
return true;
}
if (arg == "--context-files") {
if (++i >= argc) {
invalid_param = true;
return true;
}
while(true) {
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
invalid_param = true;
break;
}
// store the external file name in params
params.context_files.push_back(argv[i]);
if (i + 1 >= argc || argv[i + 1][0] == '-') {
break;
}
i++;
}
return true;
}
if (arg == "--chunk-size") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.chunk_size = std::stoi(argv[i]);
return true;
}
if (arg == "--chunk-separator") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.chunk_separator = argv[i];
return true;
}
if (arg == "-n" || arg == "--n-predict") {
if (++i >= argc) {
invalid_param = true;
Expand Down Expand Up @@ -1319,11 +1282,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" prompt file to start generation.\n");
printf(" -bf FNAME, --binary-file FNAME\n");
printf(" binary file containing multiple choice tasks.\n");
printf(" --context-files FNAME1 FNAME2...\n");
printf(" files containing context to embed.\n");
printf(" --chunk-size N minimum length of embedded text chunk (default:%d)\n", params.chunk_size);
printf(" --chunk-separator STRING\n");
printf(" string to separate chunks (default: \"\\n\")\n");
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
printf(" -b N, --batch-size N logical maximum batch size (default: %d)\n", params.n_batch);
Expand Down
5 changes: 2 additions & 3 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,6 @@ struct gpt_params {
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
float defrag_thold = -1.0f; // KV cache defragmentation threshold
std::vector<std::string> context_files;// context files to embed
int32_t chunk_size = 64; // chunk size for context embedding
std::string chunk_separator = "\n"; // chunk separator for context embedding

ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

Expand Down Expand Up @@ -170,6 +167,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params);

void gpt_print_usage(int argc, char ** argv, const gpt_params & params);

bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int & i, bool & invalid_param);

std::string get_system_info(const gpt_params & params);

std::string gpt_random_prompt(std::mt19937 & rng);
Expand Down
89 changes: 81 additions & 8 deletions examples/retrieval/retrieval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,80 @@
#include <algorithm>
#include <fstream>

// Parameters that apply only to the retrieval example; kept separate from
// gpt_params so common.cpp stays free of example-specific options.
struct retrieval_params {
    std::vector<std::string> context_files; // paths of files whose contents will be chunked and embedded
    int32_t chunk_size = 64;                // minimum length of an embedded text chunk
    std::string chunk_separator = "\n";     // delimiter used to split file contents into chunks
};

// Print the usage banner for the retrieval example: first the retrieval-only
// options, then the common options via gpt_print_usage.
//
// Fix: the original wrote the "usage:"/"options:" header to stderr via
// fprintf(stderr, ...) but every option line (and gpt_print_usage) to stdout
// via printf, so redirecting either stream produced a torn usage message.
// Everything now goes to stdout, consistent with gpt_print_usage.
static void retrieval_params_print_usage(int argc, char** argv, gpt_params & gpt_params, retrieval_params & params) {
    printf("usage: retrieval [options]\n");
    printf("options:\n");
    printf("  --context-files FNAME1 FNAME2...\n");
    printf("                        files containing context to embed.\n");
    printf("  --chunk-size N        minimum length of embedded text chunk (default:%d)\n", params.chunk_size);
    printf("  --chunk-separator STRING\n");
    printf("                        string to separate chunks (default: \"\\n\")\n");
    gpt_print_usage(argc, argv, gpt_params);
}

// Parse the command line for the retrieval example.
//
// Each argument is first offered to gpt_params_find_arg (the common-options
// parser); anything it does not recognize is matched against the
// retrieval-only options. Unknown arguments, missing option values, and
// invalid values print the usage text and exit(1).
//
// Fix: the original passed argv[i] straight to std::stoi for --chunk-size,
// so a non-numeric or out-of-range value terminated the process with an
// uncaught std::invalid_argument / std::out_of_range instead of a usage
// message; the conversion is now guarded.
static void retrieval_params_parse(int argc, char ** argv, gpt_params & gpt_params, retrieval_params & retrieval_params) {
    int i = 1;
    std::string arg;
    while (i < argc) {
        arg = argv[i];
        bool invalid_gpt_param = false;
        if (gpt_params_find_arg(argc, argv, gpt_params, i, invalid_gpt_param)) {
            if (invalid_gpt_param) {
                fprintf(stderr, "error: invalid argument: %s\n", arg.c_str());
                retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
                exit(1);
            }
            // option was parsed by gpt_params_find_arg
        } else if (arg == "--context-files") {
            if (++i >= argc) {
                fprintf(stderr, "error: missing argument for --context-files\n");
                retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
                exit(1);
            }
            // consume filenames until the next option (a token starting with
            // '-') or the end of argv; each file must be openable up front
            while (true) {
                std::ifstream file(argv[i]);
                if (!file) {
                    fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
                    retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
                    exit(1);
                }
                // store the external file name in params
                retrieval_params.context_files.push_back(argv[i]);
                if (i + 1 >= argc || argv[i + 1][0] == '-') {
                    break;
                }
                i++;
            }
        } else if (arg == "--chunk-size") {
            if (++i >= argc) {
                fprintf(stderr, "error: missing argument for --chunk-size\n");
                retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
                exit(1);
            }
            // std::stoi throws on non-numeric or out-of-range input; report it
            // as a usage error instead of crashing with an uncaught exception
            try {
                retrieval_params.chunk_size = std::stoi(argv[i]);
            } catch (const std::exception &) {
                fprintf(stderr, "error: invalid argument for --chunk-size: %s\n", argv[i]);
                retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
                exit(1);
            }
        } else if (arg == "--chunk-separator") {
            if (++i >= argc) {
                fprintf(stderr, "error: missing argument for --chunk-separator\n");
                retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
                exit(1);
            }
            retrieval_params.chunk_separator = argv[i];
        } else {
            // unknown argument
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            retrieval_params_print_usage(argc, argv, gpt_params, retrieval_params);
            exit(1);
        }
        i++;
    }
}

struct chunk {
// filename
std::string filename;
Expand Down Expand Up @@ -103,19 +177,18 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu

int main(int argc, char ** argv) {
gpt_params params;
retrieval_params retrieval_params;

if (!gpt_params_parse(argc, argv, params)) {
return 1;
}
retrieval_params_parse(argc, argv, params, retrieval_params);

// For BERT models, batch size must be equal to ubatch size
params.n_ubatch = params.n_batch;

if (params.chunk_size <= 0) {
if (retrieval_params.chunk_size <= 0) {
fprintf(stderr, "chunk_size must be positive\n");
return 1;
}
if (params.context_files.empty()) {
if (retrieval_params.context_files.empty()) {
fprintf(stderr, "context_files must be specified\n");
return 1;
}
Expand All @@ -128,13 +201,13 @@ int main(int argc, char ** argv) {
}

printf("processing files:\n");
for (auto & context_file : params.context_files) {
for (auto & context_file : retrieval_params.context_files) {
printf("%s\n", context_file.c_str());
}

std::vector<chunk> chunks;
for (auto & context_file : params.context_files) {
std::vector<chunk> file_chunk = chunk_file(context_file, params.chunk_size, params.chunk_separator);
for (auto & context_file : retrieval_params.context_files) {
std::vector<chunk> file_chunk = chunk_file(context_file, retrieval_params.chunk_size, retrieval_params.chunk_separator);
chunks.insert(chunks.end(), file_chunk.begin(), file_chunk.end());
}
printf("Number of chunks: %ld\n", chunks.size());
Expand Down