Skip to content

Commit

Permalink
falcon : add CUDA offloading (#2739)
Browse files Browse the repository at this point in the history
  • Loading branch information
slaren authored Aug 23, 2023
1 parent 854ae5d commit e729965
Showing 1 changed file with 101 additions and 11 deletions.
112 changes: 101 additions & 11 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1860,31 +1860,54 @@ static void llm_load_tensors(

// output
{
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, GGML_BACKEND_CPU);
model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, GGML_BACKEND_CPU);
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
ggml_backend backend_norm;
ggml_backend backend_output;

if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
#else
backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
#endif // _WIN32

backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
} else {
backend_norm = GGML_BACKEND_CPU;
backend_output = GGML_BACKEND_CPU;
}

model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm);
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
}

const uint32_t n_ff = hparams.n_ff;

const int i_gpu_start = n_layer - n_gpu_layers;

model.layers.resize(n_layer);

for (uint32_t i = 0; i < n_layer; ++i) {
const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT

auto & layer = model.layers[i];

layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, GGML_BACKEND_CPU);
layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, GGML_BACKEND_CPU);
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);

if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).c_str()) >= 0) {
layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, GGML_BACKEND_CPU);
layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, GGML_BACKEND_CPU);
layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, backend);
layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, backend);
}

layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, GGML_BACKEND_CPU);
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, GGML_BACKEND_CPU);
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);

layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, GGML_BACKEND_CPU);
layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, GGML_BACKEND_CPU);
layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
}
} break;
default:
Expand Down Expand Up @@ -2390,6 +2413,8 @@ static struct ggml_cgraph * llm_build_falcon(
const float freq_scale = hparams.rope_freq_scale;
const float norm_eps = hparams.f_norm_eps;

const int n_gpu_layers = model.n_gpu_layers;

auto & buf_compute = lctx.buf_compute;

struct ggml_init_params params = {
Expand Down Expand Up @@ -2430,6 +2455,30 @@ static struct ggml_cgraph * llm_build_falcon(
}
}

const int i_gpu_start = n_layer - n_gpu_layers;
(void) i_gpu_start;

// offload functions set the tensor output backend to GPU
// tensors are GPU-accelerated if any input or the output has been offloaded
//
// with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
// in that case ggml_cuda_assign_buffers has no effect
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
offload_func_t offload_func_kq = llama_nop;
offload_func_t offload_func_v = llama_nop;

#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > n_layer) {
offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
}
if (n_gpu_layers > n_layer + 1) {
offload_func_v = ggml_cuda_assign_buffers_no_alloc;
}
if (n_gpu_layers > n_layer + 2) {
offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
}
#endif // GGML_USE_CUBLAS

struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(lctx.alloc, KQ_scale);
if (!ggml_allocr_is_measure(lctx.alloc)) {
Expand All @@ -2440,28 +2489,43 @@ static struct ggml_cgraph * llm_build_falcon(
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * attn_norm;

offload_func_t offload_func = llama_nop;

#ifdef GGML_USE_CUBLAS
if (il >= i_gpu_start) {
offload_func = ggml_cuda_assign_buffers_no_alloc;
}
#endif // GGML_USE_CUBLAS

// self-attention
// TODO: refactor into common function (shared with LLaMA)
{
attn_norm = ggml_norm(ctx0, inpL, norm_eps);
offload_func(attn_norm);

attn_norm = ggml_add(ctx0,
ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm),
model.layers[il].attn_norm_b);
offload_func(attn_norm->src[0]);
offload_func(attn_norm);

if (model.layers[il].attn_norm_2) { // Falcon-40B
cur = ggml_norm(ctx0, inpL, norm_eps);
offload_func(cur);

cur = ggml_add(ctx0,
ggml_mul(ctx0, cur, model.layers[il].attn_norm_2),
model.layers[il].attn_norm_2_b);
offload_func(cur->src[0]);
offload_func(cur);
} else { // Falcon 7B
cur = attn_norm;
}

// compute QKV

cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
offload_func_kq(cur);

// Note that the strides for Kcur, Vcur are set up so that the
// resulting views are misaligned with the tensor's storage
Expand All @@ -2479,39 +2543,49 @@ static struct ggml_cgraph * llm_build_falcon(
wsize * n_embd_head,
wsize * n_embd_head * (n_head + 2 * n_head_kv),
0);
offload_func_kq(tmpq);

struct ggml_tensor * tmpk = ggml_view_3d(
ctx0, cur, n_embd_head, n_head_kv, N,
wsize * n_embd_head,
wsize * n_embd_head * (n_head + 2 * n_head_kv),
wsize * n_embd_head * n_head);
offload_func_kq(tmpk);

struct ggml_tensor * tmpv = ggml_view_3d(
ctx0, cur, n_embd_head, n_head_kv, N,
wsize * n_embd_head,
wsize * n_embd_head * (n_head + 2 * n_head_kv),
wsize * n_embd_head * (n_head + n_head_kv));
offload_func_v(tmpv);

// using mode = 2 for neox mode
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
offload_func_kq(Qcur);
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
offload_func_kq(Kcur);

{
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N));
offload_func_v(Vcur);
offload_func_v(Vcur->src[0]->src[0]);
ggml_set_name(Vcur, "Vcur");

struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
offload_func_kq(k);
ggml_set_name(k, "k");

struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
offload_func_v(v);

ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
}

struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
offload_func_kq(Q);
ggml_set_name(Q, "Q");

struct ggml_tensor * K =
Expand All @@ -2520,18 +2594,23 @@ static struct ggml_cgraph * llm_build_falcon(
ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
offload_func_kq(K);
ggml_set_name(K, "K");

struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
offload_func_kq(KQ);
ggml_set_name(KQ, "KQ");

struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
offload_func_kq(KQ_scaled);
ggml_set_name(KQ_scaled, "KQ_scaled");

struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
offload_func_kq(KQ_masked);
ggml_set_name(KQ_masked, "KQ_masked");

struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
offload_func_v(KQ_soft_max);
ggml_set_name(KQ_soft_max, "KQ_soft_max");

struct ggml_tensor * V =
Expand All @@ -2540,18 +2619,23 @@ static struct ggml_cgraph * llm_build_falcon(
ggml_element_size(kv_self.v)*n_ctx,
ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
offload_func_v(V);
ggml_set_name(V, "V");

struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
offload_func_v(KQV);
ggml_set_name(KQV, "KQV");

struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
offload_func_v(KQV_merged);
ggml_set_name(KQV_merged, "KQV_merged");

cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
offload_func_v(cur);
ggml_set_name(cur, "KQV_merged_contiguous");

cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
offload_func(cur);
ggml_set_name(cur, "result_wo");
}

Expand All @@ -2567,13 +2651,18 @@ static struct ggml_cgraph * llm_build_falcon(
// adding this, because there seems to be a bug in the Metal concurrency optimization
// without this line, the results are non-deterministic and wrong
cur->src[2] = attn_out;
offload_func(cur);

cur = ggml_gelu(ctx0, cur);
offload_func(cur);
cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
offload_func(cur);
}

cur = ggml_add(ctx0, cur, attn_out);
offload_func(cur);
cur = ggml_add(ctx0, cur, inpL);
offload_func(cur);

// input for next layer
inpL = cur;
Expand All @@ -2584,6 +2673,7 @@ static struct ggml_cgraph * llm_build_falcon(
// norm
{
cur = ggml_norm(ctx0, cur, norm_eps);
offload_func_nr(cur);

cur = ggml_add(ctx0,
ggml_mul(ctx0, cur, model.output_norm),
Expand Down

0 comments on commit e729965

Please sign in to comment.