Proof of concept: GPU-accelerated token generation #1375
Changes from all commits: 229aa1f, d052a0e, 3ed4588, 8a9d7ce, c46320d
@@ -225,6 +225,45 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
     }
 }
 
+template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const int row = blockIdx.x;
+    const int tid = threadIdx.x;
+
+    float partial_sum = 0; // separate sum for each thread
+
+    for (int i = 0; i < ncols/block_size; i += 2) {
+        const int col = i*block_size + 2*tid;
+
+        // dequantize
+        const float d = x[(row*ncols + col)/QK4_0].d;
+
+        const uint8_t * pp = x[(row*ncols + col)/QK4_0].qs;
+
+        const uint8_t vui = pp[((row*ncols + col)%QK4_0)/2];
+
+        const int8_t vi0 = vui & 0xF;
+        const int8_t vi1 = vui >> 4;
+
+        const float v0 = (vi0 - 8)*d;
+        const float v1 = (vi1 - 8)*d;
+
+        // matrix multiplication
+        partial_sum += v0 * y[col + 0];
+        partial_sum += v1 * y[col + 1];
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        partial_sum += __shfl_xor_sync(0xffffffff, partial_sum, mask, 32);
+    }
+    if (tid == 0) {
+        dst[row] = partial_sum;
+    }
+}
+
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK4_0;
     dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
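For readers following the kernel above: each of the 32 threads in a warp accumulates its own partial dot product, and the XOR "butterfly" exchange via `__shfl_xor_sync` leaves the full sum in every lane after five steps (masks 16, 8, 4, 2, 1); lane 0 then writes the result. A minimal standalone sketch of the same reduction pattern (my own illustration, not code from this PR):

```cpp
// Minimal sketch of the warp-level butterfly reduction used in the kernel above.
// Each lane starts with its own value; after the loop every lane holds the warp sum.
#include <cstdio>

__global__ void warp_sum_demo(float * out) {
    float v = (float) threadIdx.x;            // lane i contributes the value i
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, mask, 32);
    }
    if (threadIdx.x == 0) {
        *out = v;                             // expected: 0 + 1 + ... + 31 = 496
    }
}

int main() {
    float * d_out;
    cudaMalloc(&d_out, sizeof(float));
    warp_sum_demo<<<1, 32>>>(d_out);
    float h_out = 0.0f;
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("warp sum = %f\n", h_out);         // prints 496.0
    cudaFree(d_out);
    return 0;
}
```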
@@ -255,6 +294,23 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
     dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
 }
 
+static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {

Review comment: This should probably be renamed to …
Reply: I agree that something along the lines of …

+    // static int block_size = -1;

Review comment: I would suggest removing the commented-out code.
Reply: Already done for the version that I intend to get merged: #1412

+    // if (block_size == -1) {
+    //     int min_grid_size, max_block_size = 1;
+    //     CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &max_block_size, dequantize_mul_mat_q4_0<256>, 0, 0));
+    //     max_block_size = min(max_block_size, GGML_CUDA_MAX_BLOCK_SIZE);
+    //     block_size = 1;
+    //     while (block_size*2 <= max_block_size && block_size*2 % ncols == 0) {
+    //         block_size *= 2;
+    //     }
+    // }
+    // dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
+    const int block_size = 32;
+    GGML_ASSERT(ncols % block_size == 0);
+    dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
+}
+
 // TODO: optimize
 static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
     const half * x = (const half *) vx;
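The commented-out code above sketches choosing the launch block size with the CUDA occupancy API and caching it, instead of hard-coding 32. A standalone sketch of that pattern (hypothetical, with a dummy kernel; not the code that ships in this PR):

```cpp
// Hypothetical sketch: query the occupancy API once and cache the suggested block size.
#include <cstdio>

__global__ void dummy_kernel(const float * x, float * y, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        y[i] = 2.0f * x[i];
    }
}

static int pick_block_size() {
    static int block_size = -1;               // cached after the first call
    if (block_size == -1) {
        int min_grid_size = 0;
        cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dummy_kernel, 0, 0);
    }
    return block_size;
}

int main() {
    printf("suggested block size: %d\n", pick_block_size());
    return 0;
}
```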
@@ -290,7 +346,7 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
 }
 
 // buffer pool for cuda
-#define MAX_CUDA_BUFFERS 16
+#define MAX_CUDA_BUFFERS 256
 
 struct scoped_spin_lock {
     std::atomic_flag& lock;
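For context, the buffer pool whose size cap is raised here is guarded by the `scoped_spin_lock` visible in the surrounding code. A minimal sketch of such an RAII guard over `std::atomic_flag` (my own illustration, not necessarily the exact implementation in the file):

```cpp
// Sketch of an RAII spin-lock guard around std::atomic_flag.
#include <atomic>

struct scoped_spin_lock {
    std::atomic_flag & lock;
    explicit scoped_spin_lock(std::atomic_flag & l) : lock(l) {
        while (lock.test_and_set(std::memory_order_acquire)) {
            // busy-wait until the flag is released
        }
    }
    ~scoped_spin_lock() {
        lock.clear(std::memory_order_release);
    }
    scoped_spin_lock(const scoped_spin_lock &) = delete;
    scoped_spin_lock & operator=(const scoped_spin_lock &) = delete;
};

static std::atomic_flag g_pool_lock = ATOMIC_FLAG_INIT;

void touch_pool() {
    scoped_spin_lock guard(g_pool_lock);   // held for the duration of this scope
    // ... allocate or free a pooled buffer here ...
}
```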
@@ -597,7 +653,10 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
 
     size_t x_size, y_size, d_size, q_size;
-    float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
+    float * d_X;
+    if (ne11 > 1) {

Review comment: Suggest abstracting a function to determine whether we should use cuBLAS to compute; ne11 > 1 is one of the criteria. We can do more tests to see if there are other cases.

+        d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
+    }
     float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
     float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
     char  * d_Q = (char  *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
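The reviewer's suggested predicate could look roughly like this (hypothetical helper name; `ne11 > 1` is the only criterion this PR actually uses):

```cpp
// Hypothetical helper, as suggested in the review: centralize the decision of
// whether to take the cuBLAS path (which needs the dequantized d_X buffer) or
// the custom dequantize_mul_mat kernel. Criteria beyond ne11 > 1 would need testing.
#include "ggml.h"

static bool ggml_cuda_should_use_cublas(const ggml_tensor * src0, const ggml_tensor * src1) {
    (void) src0;                      // reserved for future criteria (type, shape, backend)
    const int64_t ne11 = src1->ne[1]; // batch dimension of src1
    return ne11 > 1;                  // single-token generation keeps the custom kernel
}
```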
@@ -612,31 +671,54 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
         cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
         cudaEvent_t  cudaEvent   = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
 
-        float * c_X = d_X + i * x_ne;
         float * c_Y = d_Y + i * y_ne;
         float * c_D = d_D + i * d_ne;
         char  * c_Q = d_Q + i * q_sz;
 
-        // copy src0 and convert to fp32 on device
-        CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
-        to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
-        CUDA_CHECK(cudaGetLastError());
-        CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+        // copy src0 to device if necessary
+        if (src0->backend == GGML_BACKEND_CPU) {
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
+        } else if (src0->backend == GGML_BACKEND_CUDA) {
+            c_Q = ((char *) src0->data) + i * q_sz;
+        } else {
+            GGML_ASSERT(false);
+        }
+        if (ne11 == 1) {
+            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
 
-        // copy src1 to device
-        CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+            // copy src1 to device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
 
-        // wait for conversion
-        CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+            // wait for data
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
 
-        // compute
-        CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
-        CUBLAS_CHECK(
-            cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                    ne01, ne11, ne10,
-                    &alpha, c_X, ne00,
-                            c_Y, ne10,
-                    &beta,  c_D, ne01));
+            // compute
+            dequantize_mul_mat_q4_0_cuda(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
+            CUDA_CHECK(cudaGetLastError());
+
+        } else {
+            float * c_X = d_X + i * x_ne;
+
+            // convert src0 to fp32 on device
+            to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
+            CUDA_CHECK(cudaGetLastError());
+            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+
+            // copy src1 to device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+
+            // wait for conversion
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+
+            // compute
+            CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
+            CUBLAS_CHECK(
+                cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                        ne01, ne11, ne10,
+                        &alpha, c_X, ne00,
+                                c_Y, ne10,
+                        &beta,  c_D, ne01));
+        }
 
         // copy dst to host
         float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
@@ -645,7 +727,9 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     }
 
     CUDA_CHECK(cudaDeviceSynchronize());
-    ggml_cuda_pool_free(d_X, x_size);
+    if (ne11 > 1) {
+        ggml_cuda_pool_free(d_X, x_size);
+    }
     ggml_cuda_pool_free(d_Y, y_size);
     ggml_cuda_pool_free(d_D, d_size);
     ggml_cuda_pool_free(d_Q, q_size);
@@ -661,8 +745,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
         src1->type == GGML_TYPE_F32 &&
         dst->type == GGML_TYPE_F32 &&
-        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
-
+        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
         return true;
     }
@@ -714,3 +797,25 @@ size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct
         return 0;
     }
 }
+
+void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
+    const int64_t ne0 = tensor->ne[0];
+    const int64_t ne1 = tensor->ne[1];
+    const int64_t ne2 = tensor->ne[2];
+    const int64_t ne3 = tensor->ne[3];
+
+    const ggml_type type = tensor->type;
+    const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
+
+    size_t q_size;
+    char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
+
+    cudaStream_t cudaStream2 = g_cudaStreams2[0];
+
+    // copy tensor to device
+    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    tensor->data = d_Q;
+    tensor->backend = GGML_BACKEND_CUDA;
+}

Review comment (on `tensor->data = d_Q;`): Will we have a memory leak here? Do we need to deallocate the original buffer?
Reply: There is at least not more of a memory leak than master, where the buffers are actually used. I plan to make a version that frees the memory in RAM (or rather doesn't allocate as much RAM in the first place), but my priority was getting a working proof of concept out.
@@ -9,6 +9,9 @@
 #include "llama.h"
 
 #include "ggml.h"
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+#endif
 
 #include <array>
 #include <ctime>
@@ -815,6 +818,7 @@ struct llama_context_params llama_context_default_params() {
         /*.vocab_only                  =*/ false,
         /*.use_mmap                    =*/ true,
         /*.use_mlock                   =*/ false,
+        /*.gpu_layers                  =*/ 0,
         /*.embedding                   =*/ false,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
@@ -877,6 +881,7 @@ static void llama_model_load_internal(
         ggml_type memory_type,
         bool use_mmap,
         bool use_mlock,
+        int gpu_layers,
         bool vocab_only,
         llama_progress_callback progress_callback,
         void * progress_callback_user_data) {
@@ -1011,6 +1016,18 @@ static void llama_model_load_internal(
     ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
 
     model.mapping = std::move(ml->mapping);
+#ifdef GGML_USE_CUBLAS
+    for (int i = 0; i < std::min(gpu_layers, int(hparams.n_layer)); ++i) {
+        auto & layer = model.layers[i];
+        ggml_cuda_transform_tensor(layer.wq);
+        ggml_cuda_transform_tensor(layer.wk);
+        ggml_cuda_transform_tensor(layer.wv);
+        ggml_cuda_transform_tensor(layer.wo);
+        ggml_cuda_transform_tensor(layer.w1);
+        ggml_cuda_transform_tensor(layer.w2);
+        ggml_cuda_transform_tensor(layer.w3);
+    }
+#endif
 
     // loading time will be recalculate after the first eval, so
     // we take page faults deferred by mmap() into consideration
Review comments on lines +1019 to +1030:

Review comment: Would it be possible to have this be a generic API from ggml instead of something tied to CUDA?
Reply: Maybe? The …
Reply: I see your point, CL would additionally need an offset, maybe.
Reply: How about this: the …
Reply: I've been working on this same problem this week, albeit from a different angle. I've been experimenting with compiling llama.cpp for lower-end devices with integrated GPUs, such as Nvidia Tegra. On those devices, the memory is physically shared between the CPU and the GPU, allowing us to use a unified address space and leverage zero-copy when doing GPU processing. I see a lot of similarities between your approach and things that I've been experimenting with. One issue that I'm running into is that the … This is the sort of thing that I'm attempting to do, but it's not fully operational yet: …
Reply: I wrote …
Reply: If tensors are kept on device, will that significantly hurt performance when the CPU operates on the graph?
Reply: This is the KV cache; it is not read by the CPU, but currently the CPU writes to it when a batch of embeddings is processed per layer. EDIT: they are not read by the CPU assuming that the matrix multiplications they are used in are never done on the CPU, like in this patch, for example.
Reply:

    // important: storing RoPE-ed version of K in the KV cache!
    ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
    ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));

If … With …
Reply: If we did this, then it would certainly make the zero-copy case much easier, because we would no longer be transforming and copying in the same motion.
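A backend-agnostic version of this loop, as floated in the thread above, might dispatch on the build-time backend rather than calling the CUDA function directly. The wrapper name below is hypothetical, and the OpenCL branch is only a placeholder (this PR implements none of it):

```cpp
// Hypothetical wrapper sketched from the review discussion: hide the backend behind
// one call so llama.cpp does not reference CUDA directly.
static void ggml_offload_tensor(ggml_tensor * tensor) {
#if defined(GGML_USE_CUBLAS)
    ggml_cuda_transform_tensor(tensor);   // existing CUDA path from this PR
#elif defined(GGML_USE_CLBLAST)
    // An OpenCL path would go here; as noted above, it may also need an offset.
    (void) tensor;
#else
    (void) tensor;                        // CPU-only build: leave the tensor in host memory
#endif
}
```

The per-layer loop would then call `ggml_offload_tensor(layer.wq)` and so on, with no `#ifdef` in llama.cpp itself.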
@@ -1024,11 +1041,12 @@ static bool llama_model_load(
         ggml_type memory_type,
         bool use_mmap,
         bool use_mlock,
+        int gpu_layers,
         bool vocab_only,
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock,
+        llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock, gpu_layers,
                                   vocab_only, progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::string & err) {
@@ -2088,7 +2106,7 @@ struct llama_context * llama_init_from_file(
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
     if (!llama_model_load(path_model, *ctx, params.n_ctx, memory_type,
-                          params.use_mmap, params.use_mlock, params.vocab_only,
+                          params.use_mmap, params.use_mlock, params.gpu_layers, params.vocab_only,
                           params.progress_callback, params.progress_callback_user_data)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         llama_free(ctx);
Review comment: Suggest gpu-layers, to align with the other options.
Reply: I don't understand what you mean.
Reply: He's suggesting you replace the underscore with a hyphen, I think.
Reply: Horenbergerb is right, sorry I didn't make it explicit: the other parameters all use hyphens.
Reply: They are not all using it consistently, but it is a different issue.
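For reference, the new field is consumed through `llama_context_params`; a caller would request GPU offload of the first N layers roughly like this (a sketch based only on the API as it appears in this diff, with error handling kept minimal and the layer count chosen arbitrarily):

```cpp
// Sketch of using the gpu_layers field added by this PR.
#include "llama.h"
#include <cstdio>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <model.bin>\n", argv[0]);
        return 1;
    }

    llama_context_params params = llama_context_default_params();
    params.gpu_layers = 20;   // offload the weights of the first 20 layers to the GPU

    llama_context * ctx = llama_init_from_file(argv[1], params);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // ... evaluate tokens as usual; offloaded layers take the CUDA mul_mat path ...

    llama_free(ctx);
    return 0;
}
```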