ggml : add RPC backend #761
Conversation
Very cool! To implement `graph_compute`, you need to map the tensors of the received graph to local copies, something like:

```
tensor map_tensor(t) {
    if t in tensor_map:
        return tensor_map[t]
    tensor new_t = t
    new_t->view_src = map_tensor(t->view_src)
    new_t->buffer = map_buffer(t->buffer)
    for i in range(GGML_MAX_SRC):
        new_t->src[i] = map_tensor(t->src[i])
    tensor_map[t] = new_t
    return new_t
}

for i in range(remote_graph->n_nodes):
    local_graph->nodes[i] = map_tensor(remote_graph->nodes[i])
```

For the CUDA backend, it is important to also wrap the buffer …
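A rough C++ rendering of the pseudocode above, for illustration only: the `map_buffer` helper and the server-side `ggml_context` are assumptions, and remapping of the `data` pointers is left out, as in the pseudocode.

```cpp
#include <unordered_map>
#include "ggml.h"
#include "ggml-backend.h"

// assumed helper: resolve a buffer received from the client to a local buffer
ggml_backend_buffer_t map_buffer(ggml_backend_buffer_t remote_buffer);

static std::unordered_map<const ggml_tensor *, ggml_tensor *> tensor_map;

static ggml_tensor * map_tensor(ggml_context * ctx, ggml_tensor * t) {
    if (t == nullptr) {
        return nullptr;
    }
    auto it = tensor_map.find(t);
    if (it != tensor_map.end()) {
        return it->second;
    }
    // create a local tensor object and shallow-copy all fields (op, params, name, ...)
    ggml_tensor * new_t = ggml_dup_tensor(ctx, t);
    *new_t = *t;
    // remap the pointers that still refer to client-side objects
    new_t->view_src = map_tensor(ctx, t->view_src);
    new_t->buffer   = map_buffer(t->buffer);
    for (int i = 0; i < GGML_MAX_SRC; ++i) {
        new_t->src[i] = map_tensor(ctx, t->src[i]);
    }
    tensor_map[t] = new_t;
    return new_t;
}

// rebuild a local graph from the one received over RPC
static void map_graph(ggml_context * ctx, const ggml_cgraph * remote_graph, ggml_cgraph * local_graph) {
    for (int i = 0; i < remote_graph->n_nodes; ++i) {
        local_graph->nodes[i] = map_tensor(ctx, remote_graph->nodes[i]);
    }
    local_graph->n_nodes = remote_graph->n_nodes;
}
```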
Thanks for the hints. I have added a …
The purpose of the RPC backend is to proxy all operations to another host where they are implemented with one of the existing backends (e.g. CUDA, Metal, etc.).
@slaren Thanks for the review, I have addressed your comments. The simple example which multiplies two tensors is working. I am currently trying to make gpt-2 work with this backend. Currently it is producing garbage:
I suspect there is something wrong with the way I reconstruct the compute graph but I am still debugging ...

What backend are you wrapping in the server? The CPU backend should be the simplest to make work.

I am wrapping the CPU backend on the server.

The issue may be that …
src/ggml-rpc.cpp
```cpp
}
result->flags = protobuf_tensor.flags();
result->data = reinterpret_cast<void *>(protobuf_tensor.data());
strncpy(result->name, protobuf_tensor.name().c_str(), GGML_MAX_NAME);
```
In practice this will work as long as the server and the client have the same `GGML_MAX_NAME`, but it is not safe to use `strncpy` in this way because it does not guarantee that the string will be NUL-terminated. `snprintf` should be safe.
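A minimal illustration of the suggested fix, using the same variables as the quoted lines above:

```cpp
// strncpy does not NUL-terminate when the source is GGML_MAX_NAME bytes or longer:
//   strncpy(result->name, protobuf_tensor.name().c_str(), GGML_MAX_NAME);

// snprintf always writes a terminating NUL, truncating the name if necessary:
snprintf(result->name, GGML_MAX_NAME, "%s", protobuf_tensor.name().c_str());
```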
I get this error when building (followed by a million more):
It seems that gRPC requires building with C++17. Am I missing something?

Build gRPC by adding …
That was exactly the problem, gpt-2 works now!

Does it work with C++11 too? Usually ggml targets C++11.

gRPC doesn't build with C++11 …
I got it to work now, very nice. Seems to work fine with CUDA as well. It is very slow, but I guess this is because of all the GetAllocSize/BufferGetBase/InitTensor calls. I think that the BufferGetBase function could be cached in the client, and the other functions could be buffered until there is a call that uses the tensor, such as a tensor_set, tensor_get or a graph_compute, and then submitted to the server in a large batch.
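A rough sketch of the first suggestion, caching the base pointer on the client so that repeated BufferGetBase round trips are avoided; the context layout and the `rpc_buffer_get_base` helper are assumptions for illustration, not code from this PR:

```cpp
#include "ggml-backend-impl.h"

// hypothetical client-side buffer context of the RPC backend
struct ggml_backend_rpc_buffer_context {
    uint64_t remote_ptr;           // handle of the buffer on the server
    void *   base_cache = nullptr; // remote base pointer, fetched once and reused
};

// assumed helper that performs the actual BufferGetBase remote call
void * rpc_buffer_get_base(uint64_t remote_ptr);

GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
    auto * ctx = (ggml_backend_rpc_buffer_context *) buffer->context;
    if (ctx->base_cache == nullptr) {
        // only the first call goes over the network
        ctx->base_cache = rpc_buffer_get_base(ctx->remote_ptr);
    }
    return ctx->base_cache;
}
```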
Yes, these exact 3 functions are called many many times.
The same prompt with CPU backend gives:
Actually …
This is the performance that I get after removing these calls (117M and 1558M):

```
diff --git a/src/ggml-alloc.c b/src/ggml-alloc.c
index e675306..8e72957 100644
--- a/src/ggml-alloc.c
+++ b/src/ggml-alloc.c
@@ -369,6 +369,7 @@ struct node_alloc {
struct ggml_gallocr {
ggml_backend_buffer_type_t * bufts; // [n_buffers]
ggml_backend_buffer_t * buffers; // [n_buffers]
+ void ** buffer_bases; // [n_buffers]
struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers]
int n_buffers;
@@ -392,6 +393,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs
galloc->buffers = calloc(sizeof(ggml_backend_buffer_t) * n_bufs, 1);
GGML_ASSERT(galloc->buffers != NULL);
+ galloc->buffer_bases = calloc(sizeof(void *) * n_bufs, 1);
+ GGML_ASSERT(galloc->buffer_bases != NULL);
+
galloc->buf_tallocs = calloc(sizeof(struct ggml_dyn_tallocr *) * n_bufs, 1);
GGML_ASSERT(galloc->buf_tallocs != NULL);
@@ -733,6 +737,7 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
#endif
ggml_backend_buffer_free(galloc->buffers[i]);
galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size);
+ galloc->buffer_bases[i] = ggml_backend_buffer_get_base(galloc->buffers[i]);
if (galloc->buffers[i] == NULL) {
fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
return false;
@@ -763,7 +768,7 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor *
if (node->data == NULL) {
assert(tensor_alloc->offset != SIZE_MAX);
assert(ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], node) <= tensor_alloc->size_max);
- void * base = ggml_backend_buffer_get_base(galloc->buffers[buffer_id]);
+ void * base = galloc->buffer_bases[buffer_id];
void * addr = (char *)base + tensor_alloc->offset;
ggml_backend_tensor_alloc(galloc->buffers[buffer_id], node, addr);
} else {
diff --git a/src/ggml-backend.c b/src/ggml-backend.c
index d60d984..3b56b4b 100644
--- a/src/ggml-backend.c
+++ b/src/ggml-backend.c
@@ -1637,20 +1637,20 @@ void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * t
tensor->buffer = buffer;
tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
tensor->backend = tensor->view_src->backend;
- ggml_backend_buffer_init_tensor(buffer, tensor);
+ //ggml_backend_buffer_init_tensor(buffer, tensor);
}
void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
GGML_ASSERT(tensor->buffer == NULL);
GGML_ASSERT(tensor->data == NULL);
GGML_ASSERT(tensor->view_src == NULL);
- GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer));
- GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
- (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer));
+ //GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer));
+ //GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
+ // (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer));
tensor->buffer = buffer;
tensor->data = addr;
- ggml_backend_buffer_init_tensor(buffer, tensor);
+ //ggml_backend_buffer_init_tensor(buffer, tensor);
}
static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies,
diff --git a/src/ggml-rpc.cpp b/src/ggml-rpc.cpp
index 0846429..e5592cd 100644
--- a/src/ggml-rpc.cpp
+++ b/src/ggml-rpc.cpp
@@ -255,6 +255,7 @@ GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_
}
GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+ return ggml_nbytes(tensor);
GGML_PRINT_DEBUG("get alloc size\n");
ggml::GetAllocSizeRequest request;
ggml::Tensor * protobuf_tensor = request.mutable_tensor();
@@ -333,7 +334,7 @@ static void add_node(ggml::GraphComputeRequest & request, ggml_tensor * node, st
add_node(request, node->view_src, visited);
ggml::Tensor * protobuf_tensor = request.add_tensors();
- GGML_PRINT_DEBUG("add node: %p\n", (void*)node);
+ //GGML_PRINT_DEBUG("add node: %p\n", (void*)node);
serialize_tensor(node, protobuf_tensor);
}
@@ -574,7 +575,7 @@ static struct ggml_tensor * create_node(uint64_t id,
}
for (int i = 0; i < request->tensors_size(); i++) {
if (request->tensors(i).id() == id) {
- GGML_PRINT_DEBUG("create node: %lx\n", id);
+ //GGML_PRINT_DEBUG("create node: %lx\n", id);
const ggml::Tensor & protobuf_tensor = request->tensors(i);
struct ggml_tensor * result = deserialize_tensor(ctx, protobuf_tensor);
            tensor_map[id] = result;
```
With the backend implementation now supporting pipeline parallelism, it seems possible to extend this RPC backend to perform distributed inference across many devices. This would be advantageous compared to the MPI backend because the latter does not support pipeline parallelisation and it is not obvious how to implement it. Are there any obvious blockers? If not, maybe we should put this on the roadmap and try to support it eventually. It would be a cool technical feat and might even unlock some interesting inference use cases.

It should be doable, but pipeline parallelism requires the ability to perform asynchronous copies between backends, and asynchronous event synchronization between backends, and it could be tricky to implement that. Servers would probably need to be able to communicate between themselves to do this.
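For reference, a sketch of the kind of cross-backend primitives this refers to; the function names follow the ggml-backend API, but the exact signatures are assumptions and may differ between ggml versions:

```cpp
#include "ggml-backend.h"

// Sketch: hand a tensor produced on backend_src over to backend_dst without
// blocking the host. A distributed RPC backend would need to expose the
// equivalent of these operations over the network.
static void hand_over_async(ggml_backend_t backend_src, ggml_backend_t backend_dst,
                            struct ggml_tensor * src, struct ggml_tensor * dst) {
    // queue the copy on backend_src; backend_dst sees `dst` once it completes
    ggml_backend_tensor_copy_async(backend_src, backend_dst, src, dst);

    // record an event after the copy on backend_src ...
    ggml_backend_event_t event = ggml_backend_event_new(backend_src);
    ggml_backend_event_record(event);

    // ... and make backend_dst wait for it asynchronously (no host-side stall)
    ggml_backend_event_wait(backend_dst, event);

    // in practice events would be created once and reused across iterations
    ggml_backend_event_free(event);
}
```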
I believe distributed ggml would be a huge win, especially for very large models like Grok. Async operations are not a problem with gRPC but I need to get more familiar with the pipeline parallelism. In any case, I think this would be much better compared to MPI in the long term.

Superseded by: ggml-org/llama.cpp#6829
We have a use case where we want to build and run ggml programs on low end machines (without GPUs) and leverage the computational resources of some high end machines (with GPUs) over the network. In this PR I am trying to prototype an RPC backend which proxies all operations to another host. On the remote host, the RPC backend simply delegates to one of the existing backends (CUDA, Metal, etc.):
I am using gRPC for the remote calls; you can find the interface definition in `ggml-rpc.proto`. I have a simple program (`client.cpp`) which creates some tensors and successfully stores and retrieves data into them using the RPC backend. You can give it a try with the following steps:
With this configuration the RPC backend will delegate to the CUDA backend.
3. Start the RPC server:
I am currently looking for some guidance on how to implement `graph_compute` with this approach. Any help is appreciated.
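For context, a minimal client along these lines might look roughly like the sketch below; the `ggml-rpc.h` header name, the `ggml_backend_rpc_init` function and the endpoint string are assumptions about this PR's API, and error handling is omitted:

```cpp
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-rpc.h"   // assumed header of the RPC backend

int main() {
    // connect to the remote host running the RPC server (endpoint format assumed)
    ggml_backend_t backend = ggml_backend_rpc_init("localhost:50051");

    // create tensor metadata only; the actual storage lives on the remote host
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);

    // allocate all tensors of the context in a buffer of the RPC backend
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    // set_tensor / get_tensor calls are proxied to the server
    float data[16];
    for (int i = 0; i < 16; i++) data[i] = (float) i;
    ggml_backend_tensor_set(a, data, 0, sizeof(data));

    float out[16];
    ggml_backend_tensor_get(a, out, 0, sizeof(out));

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}
```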