Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[libc] GPU RPC interface: add return value to rpc_host_call #111288

Merged
merged 3 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libc/newhdrgen/yaml/gpu/rpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ functions:
- name: rpc_host_call
standards:
- GPUExtensions
return_type: void
return_type: unsigned long long
arguments:
- type: void *
- type: void *
Expand Down
2 changes: 1 addition & 1 deletion libc/spec/gpu_ext.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def GPUExtensions : StandardSpec<"GPUExtensions"> {
[
FunctionSpec<
"rpc_host_call",
RetValSpec<VoidType>,
RetValSpec<UnsignedLongLongType>,
[ArgSpec<VoidPtr>, ArgSpec<VoidPtr>, ArgSpec<SizeTType>]
>,
]
Expand Down
9 changes: 7 additions & 2 deletions libc/src/gpu/rpc_host_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,19 @@ namespace LIBC_NAMESPACE_DECL {

// This calls the associated function pointer on the RPC server with the given
// arguments. We expect that the pointer here is a valid pointer on the server.
//
// Parameters:
//   fn   - function pointer value; it is written into data[0] of the packet
//          sent to the server.
//   data - start of the argument buffer handed to send_n.
//   size - size passed to send_n together with data (presumably the number of
//          bytes in the argument buffer; confirm against send_n).
// Returns the value the server placed in data[0] of its reply.
//
// NOTE(review): this is a diff view whose +/- markers were lost. The
// single-line `void` signature and the empty recv callback below are the two
// removed pre-change lines; the lines following each are their replacements.
LLVM_LIBC_FUNCTION(void, rpc_host_call, (void *fn, void *data, size_t size)) {
LLVM_LIBC_FUNCTION(unsigned long long, rpc_host_call,
(void *fn, void *data, size_t size)) {
rpc::Client::Port port = rpc::client.open<RPC_HOST_CALL>();
// Send the argument buffer first, then the function pointer in data[0].
port.send_n(data, size);
port.send([=](rpc::Buffer *buffer) {
buffer->data[0] = reinterpret_cast<uintptr_t>(fn);
});
// pre-change: the reply was received but its contents were ignored.
port.recv([](rpc::Buffer *) {});
// Assigned inside the recv callback below before it is returned.
unsigned long long ret;
port.recv([&](rpc::Buffer *buffer) {
ret = static_cast<unsigned long long>(buffer->data[0]);
});
port.close();
return ret;
}

} // namespace LIBC_NAMESPACE_DECL
2 changes: 1 addition & 1 deletion libc/src/gpu/rpc_host_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

namespace LIBC_NAMESPACE_DECL {

// NOTE(review): diff view with the +/- markers lost; the `void` declaration is
// the removed pre-change line and the one after it is its replacement.
void rpc_host_call(void *fn, void *buffer, size_t size);
// Declares rpc_host_call taking a function pointer value, a buffer, and a
// size, and returning an unsigned long long.
unsigned long long rpc_host_call(void *fn, void *buffer, size_t size);

} // namespace LIBC_NAMESPACE_DECL

Expand Down
9 changes: 7 additions & 2 deletions libc/utils/gpu/server/rpc_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,13 +319,18 @@ rpc_status_t handle_server_impl(
}
// RPC_HOST_CALL: call a function pointer supplied by the client and send the
// function's return value back in data[0].
// NOTE(review): this is a diff view whose +/- markers were lost; lines tagged
// "pre-change" below are the removed versions of the lines that follow them.
case RPC_HOST_CALL: {
uint64_t sizes[lane_size] = {0};
// One result slot per id: written by the recv callback, read by the send one.
unsigned long long results[lane_size] = {0};
void *args[lane_size] = {nullptr};
// Receive the argument buffers; storage for each comes from temp_storage.
port->recv_n(args, sizes,
[&](uint64_t size) { return temp_storage.alloc(size); });
port->recv([&](rpc::Buffer *buffer, uint32_t id) {
// pre-change: invoked the pointer as void (*)(void *), producing no result.
reinterpret_cast<void (*)(void *)>(buffer->data[0])(args[id]);
// data[0] holds the function pointer value; call it with the argument buffer
// received for the same id and record what it returns.
using func_ptr_t = unsigned long long (*)(void *);
auto func = reinterpret_cast<func_ptr_t>(buffer->data[0]);
results[id] = func(args[id]);
});
// Reply with the recorded result for each id in data[0].
port->send([&](rpc::Buffer *buffer, uint32_t id) {
buffer->data[0] = static_cast<uint64_t>(results[id]);
});
// pre-change: replied with an empty callback that wrote nothing.
port->send([&](rpc::Buffer *, uint32_t id) {});
break;
}
case RPC_FEOF: {
Expand Down
20 changes: 15 additions & 5 deletions offload/test/libc/host_call.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

// Declarations used when the OpenMP device kind is gpu.
#pragma omp begin declare variant match(device = {kind(gpu)})
// Extension provided by the 'libc' project.
// NOTE(review): diff view with the +/- markers lost; the `void` prototype is
// the removed pre-change line and the one after it is its replacement.
void rpc_host_call(void *fn, void *args, size_t size);
unsigned long long rpc_host_call(void *fn, void *args, size_t size);
// Mark rpc_host_call as a device-only (nohost) declare-target symbol.
#pragma omp declare target to(rpc_host_call) device_type(nohost)
#pragma omp end declare variant

#pragma omp begin declare variant match(device = {kind(cpu)})
// Dummy host implementation to make this work for all targets.
void rpc_host_call(void *fn, void *args, size_t size) {
((void (*)(void *))fn)(args);
unsigned long long rpc_host_call(void *fn, void *args, size_t size) {
return ((unsigned long long (*)(void *))fn)(args);
}
#pragma omp end declare variant

Expand All @@ -25,17 +25,26 @@ typedef struct args_s {
} args_t;

// CHECK-DAG: Thread: 0, Block: 0
// CHECK-DAG: Result: 42
// CHECK-DAG: Thread: 1, Block: 0
// CHECK-DAG: Result: 42
// CHECK-DAG: Thread: 0, Block: 1
// CHECK-DAG: Result: 42
// CHECK-DAG: Thread: 1, Block: 1
// CHECK-DAG: Result: 42
// CHECK-DAG: Thread: 0, Block: 2
// CHECK-DAG: Result: 42
// CHECK-DAG: Thread: 1, Block: 2
// CHECK-DAG: Result: 42
// CHECK-DAG: Thread: 0, Block: 3
// CHECK-DAG: Result: 42
// CHECK-DAG: Thread: 1, Block: 3
void foo(void *data) {
// CHECK-DAG: Result: 42
// Host-side callback reached through rpc_host_call. Prints the thread and
// block ids stored in the args_t that `data` points to, then returns 42.
//
// The return type is `unsigned long long` so the definition matches the
// `unsigned long long (*)(void *)` pointer type it is invoked through; calling
// a function through a pointer to an incompatible function type is undefined
// behavior in C. The returned value and the printed output are unchanged.
unsigned long long foo(void *data) {
  // This callback is expected to run on the host, not on the offload device.
  assert(omp_is_initial_device() && "Not executing on host?");
  args_t *args = (args_t *)data;
  printf("Thread: %d, Block: %d\n", args->thread_id, args->block_id);
  return 42;
}

void *fn_ptr = NULL;
Expand All @@ -49,6 +58,7 @@ int main() {
#pragma omp parallel num_threads(2)
{
args_t args = {omp_get_thread_num(), omp_get_team_num()};
rpc_host_call(fn_ptr, &args, sizeof(args_t));
unsigned long long res = rpc_host_call(fn_ptr, &args, sizeof(args_t));
printf("Result: %d\n", (int)res);
}
}
Loading