Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rpc: free buffer after client disconnect #7378

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions ggml-rpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <list>
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
# ifndef NOMINMAX
Expand Down Expand Up @@ -731,7 +732,7 @@ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint

// RPC server-side implementation

static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
static ggml_backend_buffer_t rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
// input serialization format: | size (8 bytes) |
uint64_t size;
memcpy(&size, input.data(), sizeof(size));
Expand All @@ -744,6 +745,7 @@ static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t>
output.resize(2*sizeof(uint64_t), 0);
memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
return buffer;
}

static void rpc_get_alignment(ggml_backend_t backend, std::vector<uint8_t> & output) {
Expand Down Expand Up @@ -777,13 +779,14 @@ static void rpc_buffer_get_base(const std::vector<uint8_t> & input, std::vector<
memcpy(output.data(), &base_ptr, sizeof(base_ptr));
}

static void rpc_free_buffer(const std::vector<uint8_t> & input) {
static ggml_backend_buffer_t rpc_free_buffer(const std::vector<uint8_t> & input) {
// input serialization format: | remote_ptr (8 bytes) |
uint64_t remote_ptr;
memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
ggml_backend_buffer_free(buffer);
return buffer;
}

static void rpc_buffer_clear(const std::vector<uint8_t> & input) {
Expand Down Expand Up @@ -917,6 +920,7 @@ static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t>
}

static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
std::list<ggml_backend_buffer_t> allocated_buffers;
while (true) {
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
Expand All @@ -934,7 +938,7 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
}
switch (cmd) {
case ALLOC_BUFFER: {
rpc_alloc_buffer(backend, input, output);
allocated_buffers.push_back(rpc_alloc_buffer(backend, input, output));
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add the allocated buffer into list.

break;
}
case GET_ALIGNMENT: {
Expand All @@ -950,7 +954,7 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
break;
}
case FREE_BUFFER: {
rpc_free_buffer(input);
allocated_buffers.remove(rpc_free_buffer(input));
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the freed buffer from list

break;
}
case BUFFER_CLEAR: {
Expand Down Expand Up @@ -993,6 +997,10 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
break;
}
}

for (auto buff: allocated_buffers) {
ggml_backend_buffer_free(buff);
}
Copy link
Author

@chraac chraac May 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

free the reminding buffers.

}

void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
Expand Down
Loading