Skip to content

llama : quantize up to 31% faster on Linux with mmap #3206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 29, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6027,7 +6027,18 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
nthread = std::thread::hardware_concurrency();
}

llama_model_loader ml(fname_inp, /*use_mmap*/ false);
// mmap consistently increases speed Linux, and also increases speed on Windows with
// hot cache. It may cause a slowdown on macOS, possibly related to free memory.
#if defined(__linux__) || defined(_WIN32)
constexpr bool use_mmap = true;
#else
constexpr bool use_mmap = false;
#endif

llama_model_loader ml(fname_inp, use_mmap);
if (ml.use_mmap) {
ml.mapping.reset(new llama_mmap(&ml.file, /* prefetch */ 0, ggml_is_numa()));
}

llama_model model;
llm_load_arch(ml, model);
Expand Down Expand Up @@ -6105,10 +6116,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s

const std::string name = ggml_get_name(tensor);

if (read_data.size() < ggml_nbytes(tensor)) {
read_data.resize(ggml_nbytes(tensor));
if (!ml.use_mmap) {
if (read_data.size() < ggml_nbytes(tensor)) {
read_data.resize(ggml_nbytes(tensor));
}
tensor->data = read_data.data();
}
tensor->data = read_data.data();
ml.load_data_for(tensor);

LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
Expand Down