Make rpc_host_call return the value produced by the host callback.

Summary of what the hunks below do:
  * The device-side entry point `rpc_host_call(fn, data, size)` changes its
    return type from `void` to `unsigned long long` in the YAML header
    description, the TableGen spec, the implementation and its header.
  * The device implementation reads the result from `buffer->data[0]` of the
    reply packet instead of discarding the reply.
  * The RPC server calls the host function pointer as
    `unsigned long long (*)(void *)`, keeps one result per lane, and writes
    each result into `buffer->data[0]` of the reply.
  * The offload test's host callback returns 42 and every device thread
    prints the value it got back ("Result: 42").

NOTE(review): this patch was reconstructed from a copy whose line breaks were
collapsed and whose `<...>` template arguments had been stripped. Indentation
of context lines and the restored template arguments (`<uintptr_t>`,
`<RPC_HOST_CALL>`, `<VoidType>`, `<func_ptr_t>`, ...) are presumed from the
surrounding declarations -- verify against the target tree. If a hunk is
rejected on whitespace alone, retry with `git apply --ignore-whitespace`.

NOTE(review): the test callback `foo` is declared `unsigned long long` here so
that its type matches the `unsigned long long (*)(void *)` pointer it is
called through; calling it through a mismatched function type is undefined
behaviour in C.

diff --git a/libc/newhdrgen/yaml/gpu/rpc.yaml b/libc/newhdrgen/yaml/gpu/rpc.yaml
index 61856bc0c7d692..da4f6afb7856d2 100644
--- a/libc/newhdrgen/yaml/gpu/rpc.yaml
+++ b/libc/newhdrgen/yaml/gpu/rpc.yaml
@@ -16,7 +16,7 @@ functions:
   - name: rpc_host_call
     standards:
       - GPUExtensions
-    return_type: void
+    return_type: unsigned long long
     arguments:
       - type: void *
       - type: void *
diff --git a/libc/spec/gpu_ext.td b/libc/spec/gpu_ext.td
index dce81ff7786203..d99531dc06bcd6 100644
--- a/libc/spec/gpu_ext.td
+++ b/libc/spec/gpu_ext.td
@@ -7,7 +7,7 @@ def GPUExtensions : StandardSpec<"GPUExtensions"> {
     [
         FunctionSpec<
             "rpc_host_call",
-            RetValSpec<VoidType>,
+            RetValSpec<UnsignedLongLongType>,
             [ArgSpec<VoidPtr>, ArgSpec<VoidPtr>, ArgSpec<SizeTType>]
         >,
     ]
diff --git a/libc/src/gpu/rpc_host_call.cpp b/libc/src/gpu/rpc_host_call.cpp
index ca2e331340a6cb..f21fadc319c615 100644
--- a/libc/src/gpu/rpc_host_call.cpp
+++ b/libc/src/gpu/rpc_host_call.cpp
@@ -17,14 +17,19 @@ namespace LIBC_NAMESPACE_DECL {
 
 // This calls the associated function pointer on the RPC server with the given
 // arguments. We expect that the pointer here is a valid pointer on the server.
-LLVM_LIBC_FUNCTION(void, rpc_host_call, (void *fn, void *data, size_t size)) {
+LLVM_LIBC_FUNCTION(unsigned long long, rpc_host_call,
+                   (void *fn, void *data, size_t size)) {
   rpc::Client::Port port = rpc::client.open<RPC_HOST_CALL>();
   port.send_n(data, size);
   port.send([=](rpc::Buffer *buffer) {
     buffer->data[0] = reinterpret_cast<uintptr_t>(fn);
   });
-  port.recv([](rpc::Buffer *) {});
+  unsigned long long ret;
+  port.recv([&](rpc::Buffer *buffer) {
+    ret = static_cast<unsigned long long>(buffer->data[0]);
+  });
   port.close();
+  return ret;
 }
 
 } // namespace LIBC_NAMESPACE_DECL
diff --git a/libc/src/gpu/rpc_host_call.h b/libc/src/gpu/rpc_host_call.h
index 7cfea757ccdfd1..861149dead561e 100644
--- a/libc/src/gpu/rpc_host_call.h
+++ b/libc/src/gpu/rpc_host_call.h
@@ -14,7 +14,7 @@
 
 namespace LIBC_NAMESPACE_DECL {
 
-void rpc_host_call(void *fn, void *buffer, size_t size);
+unsigned long long rpc_host_call(void *fn, void *buffer, size_t size);
 
 } // namespace LIBC_NAMESPACE_DECL
 
diff --git a/libc/utils/gpu/server/rpc_server.cpp b/libc/utils/gpu/server/rpc_server.cpp
index 6951c5ae147df7..ca10e67509ae63 100644
--- a/libc/utils/gpu/server/rpc_server.cpp
+++ b/libc/utils/gpu/server/rpc_server.cpp
@@ -319,13 +319,18 @@ rpc_status_t handle_server_impl(
   }
   case RPC_HOST_CALL: {
     uint64_t sizes[lane_size] = {0};
+    unsigned long long results[lane_size] = {0};
     void *args[lane_size] = {nullptr};
     port->recv_n(args, sizes,
                  [&](uint64_t size) { return temp_storage.alloc(size); });
     port->recv([&](rpc::Buffer *buffer, uint32_t id) {
-      reinterpret_cast<void (*)(void *)>(buffer->data[0])(args[id]);
+      using func_ptr_t = unsigned long long (*)(void *);
+      auto func = reinterpret_cast<func_ptr_t>(buffer->data[0]);
+      results[id] = func(args[id]);
+    });
+    port->send([&](rpc::Buffer *buffer, uint32_t id) {
+      buffer->data[0] = static_cast<uint64_t>(results[id]);
     });
-    port->send([&](rpc::Buffer *, uint32_t id) {});
     break;
   }
   case RPC_FEOF: {
diff --git a/offload/test/libc/host_call.c b/offload/test/libc/host_call.c
index 11260cc285765d..61c4e14d5b3881 100644
--- a/offload/test/libc/host_call.c
+++ b/offload/test/libc/host_call.c
@@ -8,14 +8,14 @@
 
 #pragma omp begin declare variant match(device = {kind(gpu)})
 // Extension provided by the 'libc' project.
-void rpc_host_call(void *fn, void *args, size_t size);
+unsigned long long rpc_host_call(void *fn, void *args, size_t size);
 #pragma omp declare target to(rpc_host_call) device_type(nohost)
 #pragma omp end declare variant
 
 #pragma omp begin declare variant match(device = {kind(cpu)})
 // Dummy host implementation to make this work for all targets.
-void rpc_host_call(void *fn, void *args, size_t size) {
-  ((void (*)(void *))fn)(args);
+unsigned long long rpc_host_call(void *fn, void *args, size_t size) {
+  return ((unsigned long long (*)(void *))fn)(args);
 }
 #pragma omp end declare variant
 
@@ -25,17 +25,26 @@ typedef struct args_s {
 } args_t;
 
 // CHECK-DAG: Thread: 0, Block: 0
+// CHECK-DAG: Result: 42
 // CHECK-DAG: Thread: 1, Block: 0
+// CHECK-DAG: Result: 42
 // CHECK-DAG: Thread: 0, Block: 1
+// CHECK-DAG: Result: 42
 // CHECK-DAG: Thread: 1, Block: 1
+// CHECK-DAG: Result: 42
 // CHECK-DAG: Thread: 0, Block: 2
+// CHECK-DAG: Result: 42
 // CHECK-DAG: Thread: 1, Block: 2
+// CHECK-DAG: Result: 42
 // CHECK-DAG: Thread: 0, Block: 3
+// CHECK-DAG: Result: 42
 // CHECK-DAG: Thread: 1, Block: 3
-void foo(void *data) {
+// CHECK-DAG: Result: 42
+unsigned long long foo(void *data) {
   assert(omp_is_initial_device() && "Not executing on host?");
   args_t *args = (args_t *)data;
   printf("Thread: %d, Block: %d\n", args->thread_id, args->block_id);
+  return 42;
 }
 
 void *fn_ptr = NULL;
@@ -49,6 +58,7 @@ int main() {
 #pragma omp parallel num_threads(2)
   {
     args_t args = {omp_get_thread_num(), omp_get_team_num()};
-    rpc_host_call(fn_ptr, &args, sizeof(args_t));
+    unsigned long long res = rpc_host_call(fn_ptr, &args, sizeof(args_t));
+    printf("Result: %d\n", (int)res);
   }
 }