Skip to content

Commit

Permalink
[libc] GPU RPC interface: add return value to rpc_host_call (#111288)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hardcode84 authored Oct 6, 2024
1 parent 56757e5 commit 26ca8ef
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 12 deletions.
2 changes: 1 addition & 1 deletion libc/newhdrgen/yaml/gpu/rpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ functions:
- name: rpc_host_call
standards:
- GPUExtensions
return_type: void
return_type: unsigned long long
arguments:
- type: void *
- type: void *
Expand Down
2 changes: 1 addition & 1 deletion libc/spec/gpu_ext.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def GPUExtensions : StandardSpec<"GPUExtensions"> {
[
FunctionSpec<
"rpc_host_call",
RetValSpec<VoidType>,
RetValSpec<UnsignedLongLongType>,
[ArgSpec<VoidPtr>, ArgSpec<VoidPtr>, ArgSpec<SizeTType>]
>,
]
Expand Down
9 changes: 7 additions & 2 deletions libc/src/gpu/rpc_host_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,19 @@ namespace LIBC_NAMESPACE_DECL {

// This calls the associated function pointer on the RPC server with the given
// arguments. We expect that the pointer here is a valid pointer on the server.
LLVM_LIBC_FUNCTION(void, rpc_host_call, (void *fn, void *data, size_t size)) {
LLVM_LIBC_FUNCTION(unsigned long long, rpc_host_call,
(void *fn, void *data, size_t size)) {
rpc::Client::Port port = rpc::client.open<RPC_HOST_CALL>();
port.send_n(data, size);
port.send([=](rpc::Buffer *buffer) {
buffer->data[0] = reinterpret_cast<uintptr_t>(fn);
});
port.recv([](rpc::Buffer *) {});
unsigned long long ret;
port.recv([&](rpc::Buffer *buffer) {
ret = static_cast<unsigned long long>(buffer->data[0]);
});
port.close();
return ret;
}

} // namespace LIBC_NAMESPACE_DECL
2 changes: 1 addition & 1 deletion libc/src/gpu/rpc_host_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

namespace LIBC_NAMESPACE_DECL {

void rpc_host_call(void *fn, void *buffer, size_t size);
unsigned long long rpc_host_call(void *fn, void *buffer, size_t size);

} // namespace LIBC_NAMESPACE_DECL

Expand Down
9 changes: 7 additions & 2 deletions libc/utils/gpu/server/rpc_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,13 +319,18 @@ rpc_status_t handle_server_impl(
}
case RPC_HOST_CALL: {
uint64_t sizes[lane_size] = {0};
unsigned long long results[lane_size] = {0};
void *args[lane_size] = {nullptr};
port->recv_n(args, sizes,
[&](uint64_t size) { return temp_storage.alloc(size); });
port->recv([&](rpc::Buffer *buffer, uint32_t id) {
reinterpret_cast<void (*)(void *)>(buffer->data[0])(args[id]);
using func_ptr_t = unsigned long long (*)(void *);
auto func = reinterpret_cast<func_ptr_t>(buffer->data[0]);
results[id] = func(args[id]);
});
port->send([&](rpc::Buffer *buffer, uint32_t id) {
buffer->data[0] = static_cast<uint64_t>(results[id]);
});
port->send([&](rpc::Buffer *, uint32_t id) {});
break;
}
case RPC_FEOF: {
Expand Down
20 changes: 15 additions & 5 deletions offload/test/libc/host_call.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

#pragma omp begin declare variant match(device = {kind(gpu)})
// Extension provided by the 'libc' project.
void rpc_host_call(void *fn, void *args, size_t size);
unsigned long long rpc_host_call(void *fn, void *args, size_t size);
#pragma omp declare target to(rpc_host_call) device_type(nohost)
#pragma omp end declare variant

#pragma omp begin declare variant match(device = {kind(cpu)})
// Dummy host implementation to make this work for all targets.
void rpc_host_call(void *fn, void *args, size_t size) {
((void (*)(void *))fn)(args);
unsigned long long rpc_host_call(void *fn, void *args, size_t size) {
return ((unsigned long long (*)(void *))fn)(args);
}
#pragma omp end declare variant

Expand All @@ -25,17 +25,26 @@ typedef struct args_s {
} args_t;

// CHECK-DAG: Thread: 0, Block: 0
// CHECK-DAG: Result: 42
// CHECK-DAG: Thread: 1, Block: 0
// CHECK-DAG: Result: 42
// CHECK-DAG: Thread: 0, Block: 1
// CHECK-DAG: Result: 42
// CHECK-DAG: Thread: 1, Block: 1
// CHECK-DAG: Result: 42
// CHECK-DAG: Thread: 0, Block: 2
// CHECK-DAG: Result: 42
// CHECK-DAG: Thread: 1, Block: 2
// CHECK-DAG: Result: 42
// CHECK-DAG: Thread: 0, Block: 3
// CHECK-DAG: Result: 42
// CHECK-DAG: Thread: 1, Block: 3
void foo(void *data) {
// CHECK-DAG: Result: 42
long long foo(void *data) {
assert(omp_is_initial_device() && "Not executing on host?");
args_t *args = (args_t *)data;
printf("Thread: %d, Block: %d\n", args->thread_id, args->block_id);
return 42;
}

void *fn_ptr = NULL;
Expand All @@ -49,6 +58,7 @@ int main() {
#pragma omp parallel num_threads(2)
{
args_t args = {omp_get_thread_num(), omp_get_team_num()};
rpc_host_call(fn_ptr, &args, sizeof(args_t));
unsigned long long res = rpc_host_call(fn_ptr, &args, sizeof(args_t));
printf("Result: %d\n", (int)res);
}
}

0 comments on commit 26ca8ef

Please sign in to comment.