Skip to content

Commit

Permalink
llama : ggml-backend integration (ggerganov#4766)
Browse files Browse the repository at this point in the history
* llama : ggml-backend integration

* ggml-backend : add names to buffers

* fix unmap after loading

* batched-bench : add tensor_split param

* llama : check for null tensor_split

* ggml-backend : increase GGML_MAX_BACKENDS

* improve graph splitting, partial fix for --no-kv-offload

* cuda : add ggml-backend split buffer support

* cuda : do not create buffer types for devices that don't exist (fixes usage without CUDA devices available)

* ggml : fix null backend dereference (ggerganov#4807)

* ggml : fix null backend dereference

* ggml : also check ggml_backend_is_cpu

* test-backend-ops : check buffer allocation failures

* llama : add cparam (split_mode) and command line argument (--split-mode, -sm) to configure the split mode (none, layer or row)

* ggml : fix mul_mat_id work size

* llama : rewrite session kv load/set without graphs

* minor

* llama : only initialize used backends, free backends on context free

* llama : abort ctx if cuda backend init fails

* llama : rewrite lora with ggml-backend and compute on CPU

ggml-ci

* llama : only map to a backend buffer the region of the file mapping containing the tensors used in the buffer

* opencl : add ggml-backend buffer type

* cuda : only use batched_cublas with batched mat muls (fixes fp16 tg perf)

* llama : on Metal, by default offload the full model

ggml-ci

* metal : page align the data ptr (ggerganov#4854)

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* cuda : fix split buffer free

* address review comments

* llama-bench : add split-mode parameter

* fix whitespace

* opencl : fix double initialization

* server : add --split-mode parameter

* use async copy and compute to improve multi-gpu performance

ggml-ci

* use async memcpys to copy the graph outputs to the CPU

* fix opencl

* use a host buffer for the cpu compute buffer for faster copies to the gpu

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
  • Loading branch information
3 people authored Jan 12, 2024
1 parent 584d674 commit e7e4df0
Show file tree
Hide file tree
Showing 21 changed files with 2,523 additions and 2,285 deletions.
65 changes: 39 additions & 26 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,9 +543,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
params.n_gpu_layers = std::stoi(argv[i]);
#else
#ifndef LLAMA_SUPPORTS_GPU_OFFLOAD
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
#endif
Expand All @@ -554,9 +553,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
params.n_gpu_layers_draft = std::stoi(argv[i]);
#else
#ifndef LLAMA_SUPPORTS_GPU_OFFLOAD
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
#endif
Expand All @@ -565,40 +563,53 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
#ifdef GGML_USE_CUBLAS
params.main_gpu = std::stoi(argv[i]);
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
#endif
#ifndef GGML_USE_CUBLAS
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting the main GPU has no effect.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--split-mode" || arg == "-sm") {
if (++i >= argc) {
invalid_param = true;
break;
}
std::string arg_next = argv[i];
if (arg_next == "none") {
params.split_mode = LLAMA_SPLIT_NONE;
} else if (arg_next == "layer") {
params.split_mode = LLAMA_SPLIT_LAYER;
} else if (arg_next == "row") {
params.split_mode = LLAMA_SPLIT_ROW;
} else {
invalid_param = true;
break;
}
#ifndef GGML_USE_CUBLAS
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting the split mode has no effect.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--tensor-split" || arg == "-ts") {
if (++i >= argc) {
invalid_param = true;
break;
}
#ifdef GGML_USE_CUBLAS
std::string arg_next = argv[i];

// split string by , and /
const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);

if (split_arg.size() >= LLAMA_MAX_DEVICES) {
invalid_param = true;
break;
}
for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) {
if (i < split_arg.size()) {
params.tensor_split[i] = std::stof(split_arg[i]);
} else {
params.tensor_split[i] = 0.0f;
}
}
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--no-mul-mat-q" || arg == "-nommq") {
#ifdef GGML_USE_CUBLAS
params.mul_mat_q = false;
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n");
#ifndef GGML_USE_CUBLAS
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting a tensor split has no effect.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--no-mmap") {
params.use_mmap = false;
Expand Down Expand Up @@ -915,14 +926,15 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" number of layers to store in VRAM\n");
printf(" -ngld N, --n-gpu-layers-draft N\n");
printf(" number of layers to store in VRAM for the draft model\n");
printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
printf(" how to split the model across multiple GPUs, one of:\n");
printf(" - none: use one GPU only\n");
printf(" - layer (default): split layers and KV across GPUs\n");
printf(" - row: split rows across GPUs\n");
printf(" -ts SPLIT, --tensor-split SPLIT\n");
printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
#ifdef GGML_USE_CUBLAS
printf(" -nommq, --no-mul-mat-q\n");
printf(" use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n");
printf(" Not recommended since this is both slower and uses more VRAM.\n");
#endif // GGML_USE_CUBLAS
printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
#endif
printf(" -gan N, --grp-attn-n N\n");
printf(" group-attention factor (default: %d)\n", params.grp_attn_n);
Expand Down Expand Up @@ -1041,6 +1053,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
mparams.n_gpu_layers = params.n_gpu_layers;
}
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
Expand Down
1 change: 1 addition & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ struct gpt_params {
float p_split = 0.1f; // speculative decoding split probability
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
llama_split_mode split_mode = LLAMA_SPLIT_LAYER; // how to split the model across GPUs
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
int32_t n_beams = 0; // if non-zero then use beam search of given width.
Expand Down
3 changes: 3 additions & 0 deletions examples/batched-bench/batched-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ int main(int argc, char ** argv) {

llama_model_params model_params = llama_model_default_params();

const std::vector<float> t_split (LLAMA_MAX_DEVICES, 0.0f);

model_params.n_gpu_layers = n_gpu_layers;
model_params.tensor_split = t_split.data();

llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);

Expand Down
Loading

0 comments on commit e7e4df0

Please sign in to comment.