From f9ccd5c0ae74c6c260e9be963dc0ab0410d93bdb Mon Sep 17 00:00:00 2001 From: Shoaib Kamil Date: Fri, 14 Jun 2024 14:24:17 -0400 Subject: [PATCH] No longer silently hide errors in Metal completion handlers (alternative approach) (#8240) * No longer silently hide errors in Metal completion handlers * Actually implement alternative * clang-format * Implement new API * Implement test and refine the API * Format. * Remove some debug code * Add missing includes. * Add comment noting why we manually null-terminate after strncpy * Reverse engineer Objective-C API for passing void* in a block; it turns out to be much simpler than I thought * Formatting * Don't add const-ness to declaration. --------- Co-authored-by: Steven Johnson --- src/runtime/HalideRuntimeMetal.h | 12 ++ src/runtime/metal.cpp | 175 +++++++++++++++--- test/correctness/CMakeLists.txt | 3 +- ...u_metal_completion_handler_error_check.cpp | 45 +++++ test/generator/CMakeLists.txt | 5 + ...al_completion_handler_override_aottest.cpp | 55 ++++++ ..._completion_handler_override_generator.cpp | 25 +++ 7 files changed, 296 insertions(+), 24 deletions(-) create mode 100644 test/correctness/gpu_metal_completion_handler_error_check.cpp create mode 100644 test/generator/metal_completion_handler_override_aottest.cpp create mode 100644 test/generator/metal_completion_handler_override_generator.cpp diff --git a/src/runtime/HalideRuntimeMetal.h b/src/runtime/HalideRuntimeMetal.h index 8fd0f364cebb..30762e07d8ae 100644 --- a/src/runtime/HalideRuntimeMetal.h +++ b/src/runtime/HalideRuntimeMetal.h @@ -68,6 +68,7 @@ extern uint64_t halide_metal_get_crop_offset(void *user_context, struct halide_b struct halide_metal_device; struct halide_metal_command_queue; +struct halide_metal_command_buffer; /** This prototype is exported as applications will typically need to * replace it to get Halide filters to execute on the same device and @@ -93,6 +94,17 @@ extern int halide_metal_acquire_context(void *user_context, struct halide_metal_ */ extern int halide_metal_release_context(void *user_context); +/** This function is called as part of the callback when a Metal command buffer completes. + * The return value, if not halide_error_code_success, will be stashed in Metal runtime and returned + * to the next call into the runtime, and the error string will be saved as well. + * The error string will be freed by the caller. The return value must be a valid Halide error code. + * This is called from the Metal driver, and thus: + * - Any user_context must be preserved between the call to halide_metal_run and the corresponding callback + * - The function must be thread-safe + */ +extern int halide_metal_command_buffer_completion_handler(void *user_context, struct halide_metal_command_buffer *buffer, + char **returned_error_string); + #ifdef __cplusplus } // End extern "C" #endif diff --git a/src/runtime/metal.cpp b/src/runtime/metal.cpp index abc935b0743e..1fe7d895561b 100644 --- a/src/runtime/metal.cpp +++ b/src/runtime/metal.cpp @@ -12,6 +12,7 @@ extern "C" { extern objc_id MTLCreateSystemDefaultDevice(); extern struct ObjectiveCClass _NSConcreteGlobalBlock; +extern struct ObjectiveCClass _NSConcreteStackBlock; void *dlsym(void *, const char *); #define RTLD_DEFAULT ((void *)-2) } @@ -23,8 +24,8 @@ namespace Metal { typedef halide_metal_device mtl_device; typedef halide_metal_command_queue mtl_command_queue; +typedef halide_metal_command_buffer mtl_command_buffer; struct mtl_buffer; -struct mtl_command_buffer; struct mtl_compute_command_encoder; struct mtl_blit_command_encoder; struct mtl_compute_pipeline_state; @@ -381,6 +382,8 @@ WEAK int halide_metal_release_context(void *user_context) { } // extern "C" +extern "C" size_t strnlen(const char *s, size_t maxlen); + namespace Halide { namespace Runtime { namespace Internal { @@ -389,7 +392,10 @@ namespace Metal { class MetalContextHolder { objc_id pool; void *const user_context; - int status; // must always be a valid halide_error_code_t value + int status; // must always be a valid halide_error_code_t value + static int saved_status; // must always be a valid halide_error_code_t value + static halide_mutex saved_status_mutex; // mutex for accessing saved status + static char error_string[1024]; public: mtl_device *device; @@ -404,11 +410,128 @@ class MetalContextHolder { drain_autorelease_pool(pool); } - ALWAYS_INLINE int error() const { - return status; + // We use two variants of this function: one for just checking status, and one + // that returns and clears the previous status. + ALWAYS_INLINE static int get_and_clear_saved_status(char *error_string = nullptr) { + halide_mutex_lock(&saved_status_mutex); + int result = saved_status; + saved_status = halide_error_code_success; + if (error_string != nullptr && result != halide_error_code_success && strnlen(MetalContextHolder::error_string, 1024) > 0) { + strncpy(error_string, MetalContextHolder::error_string, 1024); + // Ensure null-termination, since strncpy won't if the source string is too long + error_string[1023] = '\0'; + MetalContextHolder::error_string[0] = '\0'; + debug(nullptr) << "MetalContextHolder::get_and_clear_saved_status: " << error_string << "\n"; + } + halide_mutex_unlock(&saved_status_mutex); + return result; + } + + // Returns the previous status without clearing, and optionally copies the error string + ALWAYS_INLINE static int get_saved_status(char *error_string = nullptr) { + halide_mutex_lock(&saved_status_mutex); + int result = saved_status; + if (error_string != nullptr && result != halide_error_code_success && strnlen(MetalContextHolder::error_string, 1024) > 0) { + strncpy(error_string, MetalContextHolder::error_string, 1024); + // Ensure null-termination, since strncpy won't if the source string is too long + error_string[1023] = '\0'; + } + halide_mutex_unlock(&saved_status_mutex); + return result; + } + + ALWAYS_INLINE static void set_saved_status(int new_status, char *error_string = nullptr) { + halide_mutex_lock(&saved_status_mutex); + saved_status = new_status; + if (error_string != nullptr) { + strncpy(MetalContextHolder::error_string, error_string, 1024); + // Ensure null-termination, since strncpy won't if the source string is too long + error_string[1023] = '\0'; + debug(nullptr) << "MetalContextHolder::set_saved_status: " << error_string << "\n"; + } + halide_mutex_unlock(&saved_status_mutex); + } + + ALWAYS_INLINE int error(char *error_string = nullptr) const { + if (status != halide_error_code_success) { + return status; + } else { + return get_saved_status(error_string); + } + } + + ALWAYS_INLINE int get_and_clear_error(char *error_string = nullptr) const { + auto cleared_status = get_and_clear_saved_status(error_string); + if (status != halide_error_code_success) { + return status; + } else { + return cleared_status; + } } }; +int MetalContextHolder::saved_status = halide_error_code_success; +halide_mutex MetalContextHolder::saved_status_mutex = {0}; +char MetalContextHolder::error_string[1024] = {0}; + +} // namespace Metal +} // namespace Internal +} // namespace Runtime +} // namespace Halide + +extern "C" { +/** This function is called as part of the callback when a Metal command buffer completes. + * The return value, if not halide_error_code_success, will be stashed in Metal runtime and returned + * to the next call into the runtime, and the error string will be saved as well. + * The error string will be freed by the caller. The return value must be a valid Halide error code. + * This is called from the Metal driver, and thus: + * - Any user_context must be preserved between the call to halide_metal_run and the corresponding callback + * - The function must be thread-safe + */ +WEAK int halide_metal_command_buffer_completion_handler(void *const user_context, mtl_command_buffer *buffer, char **returned_error_string) { + objc_id buffer_error = command_buffer_error(buffer); + if (buffer_error != nullptr) { + retain_ns_object(buffer_error); + + ns_log_object(buffer_error); + + // Obtain the localized NSString for the error + typedef objc_id (*localized_description_method_t)(objc_id objc, objc_sel sel); + localized_description_method_t localized_description_method = (localized_description_method_t)&objc_msgSend; + objc_id error_ns_string = (*localized_description_method)(buffer_error, sel_getUid("localizedDescription")); + + retain_ns_object(error_ns_string); + + // Obtain a C-style string + typedef char *(*utf8_string_method_t)(objc_id objc, objc_sel sel); + utf8_string_method_t utf8_string_method = (utf8_string_method_t)&objc_msgSend; + char *error_string = (*utf8_string_method)(error_ns_string, sel_getUid("UTF8String")); + + // Copy C-style string into a fresh buffer + if (returned_error_string != nullptr) { + *returned_error_string = (char *)malloc(sizeof(char) * 1024); + if (*returned_error_string != nullptr) { + strncpy(*returned_error_string, error_string, 1024); + // Ensure null-termination, since strncpy won't if the source string is too long + (*returned_error_string)[1023] = '\0'; + } else { + debug(user_context) << "halide_metal_command_buffer_completion_handler: Failed to allocate memory for error string.\n"; + } + } + + release_ns_object(error_ns_string); + release_ns_object(buffer_error); + return halide_error_code_device_run_failed; + } + return halide_error_code_success; +} +} // extern "C" + +namespace Halide { +namespace Runtime { +namespace Internal { +namespace Metal { + struct command_buffer_completed_handler_block_descriptor_1 { unsigned long reserved; unsigned long block_size; @@ -420,24 +543,23 @@ struct command_buffer_completed_handler_block_literal { int reserved; void (*invoke)(command_buffer_completed_handler_block_literal *, mtl_command_buffer *buffer); struct command_buffer_completed_handler_block_descriptor_1 *descriptor; + void *const user_context; }; WEAK command_buffer_completed_handler_block_descriptor_1 command_buffer_completed_handler_descriptor = { 0, sizeof(command_buffer_completed_handler_block_literal)}; WEAK void command_buffer_completed_handler_invoke(command_buffer_completed_handler_block_literal *block, mtl_command_buffer *buffer) { - objc_id buffer_error = command_buffer_error(buffer); - if (buffer_error != nullptr) { - ns_log_object(buffer_error); - release_ns_object(buffer_error); - } -} + retain_ns_object(buffer); + char *error_string = nullptr; + void *const user_context = block->user_context; -WEAK command_buffer_completed_handler_block_literal command_buffer_completed_handler_block = { - &_NSConcreteGlobalBlock, - (1 << 28) | (1 << 29), // BLOCK_IS_GLOBAL | BLOCK_HAS_DESCRIPTOR - 0, command_buffer_completed_handler_invoke, - &command_buffer_completed_handler_descriptor}; + auto status = halide_metal_command_buffer_completion_handler(user_context, buffer, &error_string); + release_ns_object(buffer); + + MetalContextHolder::set_saved_status(status, error_string); + free(error_string); +} } // namespace Metal } // namespace Internal @@ -476,7 +598,7 @@ WEAK int halide_metal_device_malloc(void *user_context, halide_buffer_t *buf) { MetalContextHolder metal_context(user_context, true); if (metal_context.error()) { - return metal_context.error(); + return metal_context.get_and_clear_error(); } #ifdef DEBUG_RUNTIME @@ -544,7 +666,7 @@ WEAK int halide_metal_device_free(void *user_context, halide_buffer_t *buf) { WEAK int halide_metal_initialize_kernels(void *user_context, void **state_ptr, const char *source, int source_size) { MetalContextHolder metal_context(user_context, true); if (metal_context.error()) { - return metal_context.error(); + return metal_context.get_and_clear_error(); } #ifdef DEBUG_RUNTIME uint64_t t_before = halide_current_time_ns(user_context); @@ -600,7 +722,7 @@ WEAK int halide_metal_device_sync(void *user_context, struct halide_buffer_t *bu MetalContextHolder metal_context(user_context, true); if (metal_context.error()) { - return metal_context.error(); + return metal_context.get_and_clear_error(); } halide_metal_device_sync_internal(metal_context.queue, buffer); @@ -651,7 +773,7 @@ WEAK int halide_metal_copy_to_device(void *user_context, halide_buffer_t *buffer MetalContextHolder metal_context(user_context, true); if (metal_context.error()) { - return metal_context.error(); + return metal_context.get_and_clear_error(); } if (!(buffer->host && buffer->device)) { @@ -695,7 +817,7 @@ WEAK int halide_metal_copy_to_host(void *user_context, halide_buffer_t *buffer) MetalContextHolder metal_context(user_context, true); if (metal_context.error()) { - return metal_context.error(); + return metal_context.get_and_clear_error(); } halide_metal_device_sync_internal(metal_context.queue, buffer); @@ -738,7 +860,7 @@ WEAK int halide_metal_run(void *user_context, MetalContextHolder metal_context(user_context, true); if (metal_context.error()) { - return metal_context.error(); + return metal_context.get_and_clear_error(); } mtl_command_buffer *command_buffer = new_command_buffer(metal_context.queue, entry_name, strlen(entry_name)); @@ -882,6 +1004,13 @@ WEAK int halide_metal_run(void *user_context, threadsX, threadsY, threadsZ); end_encoding(encoder); + command_buffer_completed_handler_block_literal command_buffer_completed_handler_block = { + &_NSConcreteStackBlock, + 0, // must be 0 for stack blocks + 0, command_buffer_completed_handler_invoke, + &command_buffer_completed_handler_descriptor, + user_context}; + add_command_buffer_completed_handler(command_buffer, &command_buffer_completed_handler_block); commit_command_buffer(command_buffer); @@ -962,7 +1091,7 @@ WEAK int halide_metal_buffer_copy(void *user_context, struct halide_buffer_t *sr { MetalContextHolder metal_context(user_context, true); if (metal_context.error()) { - return metal_context.error(); + return metal_context.get_and_clear_error(); } debug(user_context) @@ -1036,7 +1165,7 @@ WEAK int metal_device_crop_from_offset(void *user_context, struct halide_buffer_t *dst) { MetalContextHolder metal_context(user_context, true); if (metal_context.error()) { - return metal_context.error(); + return metal_context.get_and_clear_error(); } dst->device_interface = src->device_interface; diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 8e4d9ecce3ff..623fdd16da5b 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -134,12 +134,13 @@ tests(GROUPS correctness gpu_data_flows.cpp gpu_different_blocks_threads_dimensions.cpp gpu_dynamic_shared.cpp + gpu_f16_intrinsics.cpp gpu_free_sync.cpp gpu_give_input_buffers_device_allocations.cpp gpu_jit_explicit_copy_to_device.cpp gpu_large_alloc.cpp gpu_many_kernels.cpp - gpu_f16_intrinsics.cpp + gpu_metal_completion_handler_error_check.cpp gpu_mixed_dimensionality.cpp gpu_mixed_shared_mem_types.cpp gpu_multi_kernel.cpp diff --git a/test/correctness/gpu_metal_completion_handler_error_check.cpp b/test/correctness/gpu_metal_completion_handler_error_check.cpp new file mode 100644 index 000000000000..f0bb396e2c12 --- /dev/null +++ b/test/correctness/gpu_metal_completion_handler_error_check.cpp @@ -0,0 +1,45 @@ +#include "Halide.h" +#include + +using namespace Halide; + +bool errored = false; + +int main(int argc, char **argv) { + Target t = get_jit_target_from_environment(); + if (!t.has_feature(Target::Metal)) { + printf("[SKIP] Metal not enabled\n"); + return 0; + } + + Func f, g; + Var c, x, ci, xi; + RVar rxi; + RDom r(0, 1000, -327600, 327600); + + // Create a function that is very costly to execute, resulting in a timeout + // on the GPU + f(x, c) = x + 0.1f * c; + f(r.x, c) += cos(sin(tan(cosh(tanh(sinh(exp(tanh(exp(log(tan(cos(exp(f(r.x, c) / cos(cosh(sinh(sin(f(r.x, c))))) / tanh(tan(tan(f(r.x, c)))))))))) + cast(cast(f(r.x, c) / cast(log(f(r.x, c)))))))))))); + + f.gpu_tile(x, c, xi, ci, 4, 4); + f.update(0).gpu_tile(r.x, c, rxi, ci, 4, 4); + + // Metal is surprisingly resilient. Run this in a loop just to make sure we trigger the error. + for (int i = 0; (i < 10) && !errored; i++) { + auto out = f.realize({1000, 100}, t); + int result = out.device_sync(); + if (result != halide_error_code_success) { + printf("Device sync failed as expected: %d\n", result); + errored = true; + } + } + + if (!errored) { + printf("There was supposed to be an error\n"); + return 1; + } + + printf("Success!\n"); + return 0; +} diff --git a/test/generator/CMakeLists.txt b/test/generator/CMakeLists.txt index fc1cbfc76e78..2c010ae07717 100644 --- a/test/generator/CMakeLists.txt +++ b/test/generator/CMakeLists.txt @@ -497,6 +497,11 @@ _add_halide_libraries(metadata_tester_ucon _add_halide_aot_tests(metadata_tester HALIDE_LIBRARIES metadata_tester metadata_tester_ucon) +# metal_completion_handler_override_aottest.cpp +# metal_completion_handler_override_generator.cpp +_add_halide_libraries(metal_completion_handler_override FEATURES user_context) +_add_halide_aot_tests(metal_completion_handler_override) + # msan_aottest.cpp # msan_generator.cpp if ("${Halide_TARGET}" MATCHES "webgpu") diff --git a/test/generator/metal_completion_handler_override_aottest.cpp b/test/generator/metal_completion_handler_override_aottest.cpp new file mode 100644 index 000000000000..d6cd7fb5a112 --- /dev/null +++ b/test/generator/metal_completion_handler_override_aottest.cpp @@ -0,0 +1,55 @@ +#include + +#include "HalideBuffer.h" +#include "HalideRuntime.h" +#include "HalideRuntimeMetal.h" + +#include "metal_completion_handler_override.h" + +struct MyUserContext { + int counter; + + MyUserContext() + : counter(0) { + } +}; + +extern "C" int halide_metal_command_buffer_completion_handler(void *const user_context, struct halide_metal_command_buffer *, char **) { + if (user_context == nullptr) { + printf("Error: user_context is nullptr\n"); + return -1; + } + auto ctx = (MyUserContext *)user_context; + ctx->counter++; + return halide_error_code_success; +} + +int main(int argc, char *argv[]) { +#if defined(TEST_METAL) + Halide::Runtime::Buffer output(32, 32); + + MyUserContext my_context; + metal_completion_handler_override(&my_context, output); + output.copy_to_host(); + + // Check the output + for (int y = 0; y < output.height(); y++) { + for (int x = 0; x < output.width(); x++) { + if (output(x, y) != x + y * 2) { + printf("Error: output(%d, %d) = %d instead of %d\n", x, y, output(x, y), x + y * 2); + return -1; + } + } + } + + if (my_context.counter < 1) { + printf("Error: completion handler was not called\n"); + return -1; + } + + printf("Success!\n"); +#else + printf("[SKIP] Metal not enabled\n"); +#endif + return 0; +} \ No newline at end of file diff --git a/test/generator/metal_completion_handler_override_generator.cpp b/test/generator/metal_completion_handler_override_generator.cpp new file mode 100644 index 000000000000..8130a87710dc --- /dev/null +++ b/test/generator/metal_completion_handler_override_generator.cpp @@ -0,0 +1,25 @@ +#include "Halide.h" + +namespace { + +class SimpleMetalPipeline : public Halide::Generator { +public: + Output> output{"output"}; + + void generate() { + Var x("x"), y("y"); + + // Create a simple pipeline that scales pixel values by 2. + output(x, y) = x + y * 2; + + Target target = get_target(); + if (target.has_gpu_feature()) { + Var xo, yo, xi, yi; + output.gpu_tile(x, y, xo, yo, xi, yi, 16, 16); + } + } +}; + +} // namespace + +HALIDE_REGISTER_GENERATOR(SimpleMetalPipeline, metal_completion_handler_override) \ No newline at end of file