Skip to content

Commit

Permalink
metal : use shared buffers between CPU and GPU (ggerganov#1696)
Browse files Browse the repository at this point in the history
* Use MTLDevice.newBufferWithBytesNoCopy to share buffers between CPU and GPU

* Page-align buffers used by Metal

* Remove trailing whitespace

* Only import unistd.h for Metal builds

* metal : remove unnecessary copies

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
  • Loading branch information
kiltyj and ggerganov authored Jun 5, 2023
1 parent efe0507 commit 9d0693b
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 16 deletions.
17 changes: 14 additions & 3 deletions ggml-metal.m
Original file line number | Diff line number | Diff line change
Expand Up @@ -195,14 +195,25 @@ bool ggml_metal_add_buffer(
}
}

size_t page_size = getpagesize();
size_t aligned_size = size;
if ((aligned_size % page_size) != 0) {
aligned_size += (page_size - (aligned_size % page_size));
}

ctx->buffers[ctx->n_buffers].name = name;
ctx->buffers[ctx->n_buffers].data = data;
ctx->buffers[ctx->n_buffers].size = size;
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytes:data length:size options:MTLResourceStorageModeShared];
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:aligned_size options:MTLResourceStorageModeShared deallocator:nil];

++ctx->n_buffers;
if (ctx->buffers[ctx->n_buffers].metal == nil) {
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
return false;
} else {
fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
}

fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB\n", __func__, name, size / 1024.0 / 1024.0);
++ctx->n_buffers;
}

return true;
Expand Down
8 changes: 8 additions & 0 deletions ggml.c
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,6 +22,10 @@
#include <float.h>
#include <limits.h>

#ifdef GGML_USE_METAL
#include <unistd.h>
#endif

// if C99 - static_assert is noop
// ref: https://stackoverflow.com/a/53923785/4039976
#ifndef static_assert
Expand Down Expand Up @@ -122,7 +126,11 @@ typedef void* thread_ret_t;
#else
inline static void* ggml_aligned_malloc(size_t size) {
void* aligned_memory = NULL;
#ifdef GGML_USE_METAL
int result = posix_memalign(&aligned_memory, getpagesize(), size);
#else
int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
#endif
if (result != 0) {
// Handle allocation failure
return NULL;
Expand Down
16 changes: 16 additions & 0 deletions llama-util.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -405,13 +405,29 @@ struct llama_buffer {
llama_buffer() = default;

// Reallocate the buffer to hold `len` bytes; previous contents are discarded.
void resize(size_t len) {
#ifdef GGML_USE_METAL
    // Metal builds allocate page-aligned memory (posix_memalign + getpagesize)
    // so the buffer can later be wrapped by newBufferWithBytesNoCopy, which
    // requires page-aligned host allocations. The matching deallocator is
    // free(), never delete[].
    free(addr);
    int result = posix_memalign((void **) &addr, getpagesize(), len);
    if (result == 0) {
        // Zero-fill to mirror value-initialization semantics.
        memset(addr, 0, len);
    }
    else {
        // Allocation failed; leave addr NULL so a later free() is a no-op.
        addr = NULL;
    }
#else
    // Non-Metal builds use plain new[]/delete[]; no alignment requirement
    // beyond the default.
    delete[] addr;
    addr = new uint8_t[len];
#endif
    // NOTE(review): size is set to len even when posix_memalign fails above
    // (addr == NULL), so size may not reflect a live allocation — callers
    // should check addr; confirm this is the intended failure contract.
    size = len;
}

// Release the buffer using the deallocator that matches the allocator
// chosen in resize(): free() for posix_memalign (Metal builds),
// delete[] for new uint8_t[] otherwise. Mixing the two is undefined
// behavior, hence the mirrored #ifdef.
~llama_buffer() {
#ifdef GGML_USE_METAL
    free(addr);
#else
    delete[] addr;
#endif
    // Defensive: null the pointer in case of use-after-destroy bugs.
    addr = NULL;
}

// disable copy and move
Expand Down
13 changes: 0 additions & 13 deletions llama.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -53,7 +53,6 @@ enum e_model {
MODEL_65B,
};


static const size_t MB = 1024*1024;

// computed for n_ctx == 2048
Expand Down Expand Up @@ -1281,12 +1280,6 @@ static bool llama_eval_internal(
ggml_set_name(embd, "embd");
memcpy(embd->data, tokens, N*ggml_element_size(embd));

#ifdef GGML_USE_METAL
if (lctx.ctx_metal && N == 1) {
ggml_metal_set_tensor(lctx.ctx_metal, embd);
}
#endif

struct ggml_tensor * cur;
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);

Expand Down Expand Up @@ -1484,12 +1477,6 @@ static bool llama_eval_internal(
}

ggml_graph_compute(ctx0, &gf);

if (lctx.ctx_metal) {
// We need to sync the CPU KV cache with the GPU KV cache
ggml_metal_set_tensor(lctx.ctx_metal, kv_self.k);
ggml_metal_set_tensor(lctx.ctx_metal, kv_self.v);
}
}
#else
ggml_graph_compute(ctx0, &gf);
Expand Down

0 comments on commit 9d0693b

Please sign in to comment.