Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

quantize: add imatrix and dataset metadata in GGUF #6658

Merged
merged 16 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ batched-bench: examples/batched-bench/batched-bench.cpp build-info.o ggml.
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

quantize: examples/quantize/quantize.cpp build-info.o ggml.o llama.o $(OBJS)
quantize: examples/quantize/quantize.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

Expand Down
88 changes: 49 additions & 39 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,54 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
return result;
}

bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
const char * sep = strchr(data, '=');
if (sep == nullptr || sep - data >= 128) {
fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data);
return false;
}
llama_model_kv_override kvo;
std::strncpy(kvo.key, data, sep - data);
kvo.key[sep - data] = 0;
sep++;
if (strncmp(sep, "int:", 4) == 0) {
sep += 4;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
kvo.val_i64 = std::atol(sep);
} else if (strncmp(sep, "float:", 6) == 0) {
sep += 6;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
kvo.val_f64 = std::atof(sep);
} else if (strncmp(sep, "bool:", 5) == 0) {
sep += 5;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
if (std::strcmp(sep, "true") == 0) {
kvo.val_bool = true;
} else if (std::strcmp(sep, "false") == 0) {
kvo.val_bool = false;
} else {
fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data);
return false;
}
} else if (strncmp(sep, "str:", 4) == 0) {
sep += 4;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
if (strlen(sep) > 127) {
fprintf(stderr, "%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
return false;
}
strncpy(kvo.val_str, sep, 127);
kvo.val_str[127] = '\0';
} else {
fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data);
return false;
}
overrides.emplace_back(std::move(kvo));
return true;
}

bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
llama_sampling_params& sparams = params.sparams;
llama_sampling_params & sparams = params.sparams;

if (arg == "-s" || arg == "--seed") {
if (++i >= argc) {
Expand Down Expand Up @@ -1240,47 +1286,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
invalid_param = true;
return true;
}
char* sep = strchr(argv[i], '=');
if (sep == nullptr || sep - argv[i] >= 128) {
fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]);
invalid_param = true;
return true;
}
struct llama_model_kv_override kvo;
std::strncpy(kvo.key, argv[i], sep - argv[i]);
kvo.key[sep - argv[i]] = 0;
sep++;
if (strncmp(sep, "int:", 4) == 0) {
sep += 4;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
kvo.int_value = std::atol(sep);
}
else if (strncmp(sep, "float:", 6) == 0) {
sep += 6;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
kvo.float_value = std::atof(sep);
}
else if (strncmp(sep, "bool:", 5) == 0) {
sep += 5;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
if (std::strcmp(sep, "true") == 0) {
kvo.bool_value = true;
}
else if (std::strcmp(sep, "false") == 0) {
kvo.bool_value = false;
}
else {
fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
invalid_param = true;
return true;
}
}
else {
if (!parse_kv_override(argv[i], params.kv_overrides)) {
fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
invalid_param = true;
return true;
}
params.kv_overrides.push_back(kvo);
return true;
}
#ifndef LOG_DISABLE_LOGS
Expand Down Expand Up @@ -1551,7 +1561,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" path to dynamic lookup cache to use for lookup decoding (updated by generation)\n");
printf(" --override-kv KEY=TYPE:VALUE\n");
printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf(" types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf(" -ptc N, --print-token-count N\n");
printf(" print token count every N tokens (default: %d)\n", params.n_print);
printf("\n");
Expand Down
2 changes: 2 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ struct gpt_params {
std::string image = ""; // path to an image file
};

bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);

bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params);

bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
Expand Down
77 changes: 44 additions & 33 deletions examples/imatrix/imatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ struct Stats {
};

struct StatParams {
std::string dataset;
std::string ofile = "imatrix.dat";
int n_output_frequency = 10;
int verbosity = 1;
Expand All @@ -46,7 +47,7 @@ class IMatrixCollector {
std::vector<float> m_src1_data;
std::vector<char> m_ids; // the expert ids from ggml_mul_mat_id
//
void save_imatrix(const char * file_name) const;
void save_imatrix(const char * file_name, const char * dataset) const;
void keep_imatrix(int ncall) const;
};

Expand Down Expand Up @@ -199,32 +200,41 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
}

void IMatrixCollector::save_imatrix() const {
save_imatrix(m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str());
save_imatrix(m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str(), m_params.dataset.c_str());
}

void IMatrixCollector::keep_imatrix(int ncall) const {
auto file_name = m_params.ofile;
if (file_name.empty()) file_name = "imatrix.dat";
file_name += ".at_";
file_name += std::to_string(ncall);
save_imatrix(file_name.c_str());
save_imatrix(file_name.c_str(), m_params.dataset.c_str());
}

void IMatrixCollector::save_imatrix(const char * fname) const {
void IMatrixCollector::save_imatrix(const char * fname, const char * dataset) const {
std::ofstream out(fname, std::ios::binary);
int n_entries = m_stats.size();
out.write((const char*)&n_entries, sizeof(n_entries));
for (auto& p : m_stats) {
out.write((const char *) &n_entries, sizeof(n_entries));
for (const auto & p : m_stats) {
int len = p.first.size();
out.write((const char*)&len, sizeof(len));
out.write((const char *) &len, sizeof(len));
out.write(p.first.c_str(), len);
out.write((const char*)&p.second.ncall, sizeof(p.second.ncall));
out.write((const char *) &p.second.ncall, sizeof(p.second.ncall));
int nval = p.second.values.size();
out.write((const char*)&nval, sizeof(nval));
if (nval > 0) out.write((const char*)p.second.values.data(), nval*sizeof(float));
out.write((const char *) &nval, sizeof(nval));
if (nval > 0) out.write((const char *) p.second.values.data(), nval * sizeof(float));
}

// Write the number of call the matrix was computed with
out.write((const char *) &m_last_call, sizeof(m_last_call));

// Write the dataset name at the end of the file to later on specify it in quantize
int n_dataset = strlen(dataset);
out.write((const char *) &n_dataset, sizeof(n_dataset));
out.write(dataset, n_dataset);

if (m_params.verbosity > 0) {
fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n",__func__,m_last_call,fname);
fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n", __func__, m_last_call, fname);
}
}

Expand Down Expand Up @@ -547,6 +557,29 @@ int main(int argc, char ** argv) {
}
}

gpt_params params;
params.n_batch = 512;
if (!gpt_params_parse(args.size(), args.data(), params)) {
return 1;
}

params.logits_all = true;
params.n_batch = std::min(params.n_batch, params.n_ctx);

print_build_info();

if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}

fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);

std::mt19937 rng(params.seed);
if (params.random_prompt) {
params.prompt = gpt_random_prompt(rng);
}

sparams.dataset = params.prompt_file;
g_collector.set_parameters(std::move(sparams));

if (!combine_files.empty()) {
Expand Down Expand Up @@ -585,28 +618,6 @@ int main(int argc, char ** argv) {
}
}

gpt_params params;
params.n_batch = 512;
if (!gpt_params_parse(args.size(), args.data(), params)) {
return 1;
}

params.logits_all = true;
params.n_batch = std::min(params.n_batch, params.n_ctx);

print_build_info();

if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}

fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);

std::mt19937 rng(params.seed);
if (params.random_prompt) {
params.prompt = gpt_random_prompt(rng);
}

llama_backend_init();
llama_numa_init(params.numa);

Expand Down
2 changes: 1 addition & 1 deletion examples/quantize/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
set(TARGET quantize)
add_executable(${TARGET} quantize.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama build_info ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
target_include_directories(${TARGET} PRIVATE ../../common)
target_compile_features(${TARGET} PRIVATE cxx_std_11)
Loading
Loading