Skip to content

Commit

Permalink
No longer silently hide errors in Metal completion handlers (alternat…
Browse files Browse the repository at this point in the history
…ive approach) (#8240)

* No longer silently hide errors in Metal completion handlers

* Actually implement alternative

* clang-format

* Implement new API

* Implement test and refine the API

* Format.

* Remove some debug code

* Add missing includes.

* Add comment noting why we manually null-terminate after strncpy

* Reverse engineer Objective-C API for passing void* in a block; it turns out to be much simpler than I thought

* Formatting

* Don't add const-ness to declaration.

---------

Co-authored-by: Steven Johnson <srj@google.com>
  • Loading branch information
shoaibkamil and steven-johnson authored Jun 14, 2024
1 parent 6c8a491 commit f9ccd5c
Show file tree
Hide file tree
Showing 7 changed files with 296 additions and 24 deletions.
12 changes: 12 additions & 0 deletions src/runtime/HalideRuntimeMetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ extern uint64_t halide_metal_get_crop_offset(void *user_context, struct halide_b

struct halide_metal_device;
struct halide_metal_command_queue;
struct halide_metal_command_buffer;

/** This prototype is exported as applications will typically need to
* replace it to get Halide filters to execute on the same device and
Expand All @@ -93,6 +94,17 @@ extern int halide_metal_acquire_context(void *user_context, struct halide_metal_
*/
extern int halide_metal_release_context(void *user_context);

/** This function is called as part of the callback when a Metal command buffer completes.
* The return value, if not halide_error_code_success, will be stashed in Metal runtime and returned
* to the next call into the runtime, and the error string will be saved as well.
* The error string will be freed by the caller. The return value must be a valid Halide error code.
* This is called from the Metal driver, and thus:
* - Any user_context must be preserved between the call to halide_metal_run and the corresponding callback
* - The function must be thread-safe
*/
extern int halide_metal_command_buffer_completion_handler(void *user_context, struct halide_metal_command_buffer *buffer,
char **returned_error_string);

#ifdef __cplusplus
} // End extern "C"
#endif
Expand Down
175 changes: 152 additions & 23 deletions src/runtime/metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
extern "C" {
extern objc_id MTLCreateSystemDefaultDevice();
extern struct ObjectiveCClass _NSConcreteGlobalBlock;
extern struct ObjectiveCClass _NSConcreteStackBlock;
void *dlsym(void *, const char *);
#define RTLD_DEFAULT ((void *)-2)
}
Expand All @@ -23,8 +24,8 @@ namespace Metal {

typedef halide_metal_device mtl_device;
typedef halide_metal_command_queue mtl_command_queue;
typedef halide_metal_command_buffer mtl_command_buffer;
struct mtl_buffer;
struct mtl_command_buffer;
struct mtl_compute_command_encoder;
struct mtl_blit_command_encoder;
struct mtl_compute_pipeline_state;
Expand Down Expand Up @@ -381,6 +382,8 @@ WEAK int halide_metal_release_context(void *user_context) {

} // extern "C"

extern "C" size_t strnlen(const char *s, size_t maxlen);

namespace Halide {
namespace Runtime {
namespace Internal {
Expand All @@ -389,7 +392,10 @@ namespace Metal {
class MetalContextHolder {
objc_id pool;
void *const user_context;
int status; // must always be a valid halide_error_code_t value
int status; // must always be a valid halide_error_code_t value
static int saved_status; // must always be a valid halide_error_code_t value
static halide_mutex saved_status_mutex; // mutex for accessing saved status
static char error_string[1024];

public:
mtl_device *device;
Expand All @@ -404,11 +410,128 @@ class MetalContextHolder {
drain_autorelease_pool(pool);
}

ALWAYS_INLINE int error() const {
return status;
// We use two variants of this function: one for just checking status, and one
// that returns and clears the previous status.
ALWAYS_INLINE static int get_and_clear_saved_status(char *error_string = nullptr) {
halide_mutex_lock(&saved_status_mutex);
int result = saved_status;
saved_status = halide_error_code_success;
if (error_string != nullptr && result != halide_error_code_success && strnlen(MetalContextHolder::error_string, 1024) > 0) {
strncpy(error_string, MetalContextHolder::error_string, 1024);
// Ensure null-termination, since strncpy won't if the source string is too long
error_string[1023] = '\0';
MetalContextHolder::error_string[0] = '\0';
debug(nullptr) << "MetalContextHolder::get_and_clear_saved_status: " << error_string << "\n";
}
halide_mutex_unlock(&saved_status_mutex);
return result;
}

// Returns the previous status without clearing, and optionally copies the error string
ALWAYS_INLINE static int get_saved_status(char *error_string = nullptr) {
halide_mutex_lock(&saved_status_mutex);
int result = saved_status;
if (error_string != nullptr && result != halide_error_code_success && strnlen(MetalContextHolder::error_string, 1024) > 0) {
strncpy(error_string, MetalContextHolder::error_string, 1024);
// Ensure null-termination, since strncpy won't if the source string is too long
error_string[1023] = '\0';
}
halide_mutex_unlock(&saved_status_mutex);
return result;
}

ALWAYS_INLINE static void set_saved_status(int new_status, char *error_string = nullptr) {
halide_mutex_lock(&saved_status_mutex);
saved_status = new_status;
if (error_string != nullptr) {
strncpy(MetalContextHolder::error_string, error_string, 1024);
// Ensure null-termination, since strncpy won't if the source string is too long
error_string[1023] = '\0';
debug(nullptr) << "MetalContextHolder::set_saved_status: " << error_string << "\n";
}
halide_mutex_unlock(&saved_status_mutex);
}

ALWAYS_INLINE int error(char *error_string = nullptr) const {
if (status != halide_error_code_success) {
return status;
} else {
return get_saved_status(error_string);
}
}

ALWAYS_INLINE int get_and_clear_error(char *error_string = nullptr) const {
auto cleared_status = get_and_clear_saved_status(error_string);
if (status != halide_error_code_success) {
return status;
} else {
return cleared_status;
}
}
};

int MetalContextHolder::saved_status = halide_error_code_success;
halide_mutex MetalContextHolder::saved_status_mutex = {0};
char MetalContextHolder::error_string[1024] = {0};

} // namespace Metal
} // namespace Internal
} // namespace Runtime
} // namespace Halide

extern "C" {
/** This function is called as part of the callback when a Metal command buffer completes.
* The return value, if not halide_error_code_success, will be stashed in Metal runtime and returned
* to the next call into the runtime, and the error string will be saved as well.
* The error string will be freed by the caller. The return value must be a valid Halide error code.
* This is called from the Metal driver, and thus:
* - Any user_context must be preserved between the call to halide_metal_run and the corresponding callback
* - The function must be thread-safe
*/
WEAK int halide_metal_command_buffer_completion_handler(void *const user_context, mtl_command_buffer *buffer, char **returned_error_string) {
objc_id buffer_error = command_buffer_error(buffer);
if (buffer_error != nullptr) {
retain_ns_object(buffer_error);

ns_log_object(buffer_error);

// Obtain the localized NSString for the error
typedef objc_id (*localized_description_method_t)(objc_id objc, objc_sel sel);
localized_description_method_t localized_description_method = (localized_description_method_t)&objc_msgSend;
objc_id error_ns_string = (*localized_description_method)(buffer_error, sel_getUid("localizedDescription"));

retain_ns_object(error_ns_string);

// Obtain a C-style string
typedef char *(*utf8_string_method_t)(objc_id objc, objc_sel sel);
utf8_string_method_t utf8_string_method = (utf8_string_method_t)&objc_msgSend;
char *error_string = (*utf8_string_method)(error_ns_string, sel_getUid("UTF8String"));

// Copy C-style string into a fresh buffer
if (returned_error_string != nullptr) {
*returned_error_string = (char *)malloc(sizeof(char) * 1024);
if (*returned_error_string != nullptr) {
strncpy(*returned_error_string, error_string, 1024);
// Ensure null-termination, since strncpy won't if the source string is too long
(*returned_error_string)[1023] = '\0';
} else {
debug(user_context) << "halide_metal_command_buffer_completion_handler: Failed to allocate memory for error string.\n";
}
}

release_ns_object(error_ns_string);
release_ns_object(buffer_error);
return halide_error_code_device_run_failed;
}
return halide_error_code_success;
}
} // extern "C"

namespace Halide {
namespace Runtime {
namespace Internal {
namespace Metal {

struct command_buffer_completed_handler_block_descriptor_1 {
unsigned long reserved;
unsigned long block_size;
Expand All @@ -420,24 +543,23 @@ struct command_buffer_completed_handler_block_literal {
int reserved;
void (*invoke)(command_buffer_completed_handler_block_literal *, mtl_command_buffer *buffer);
struct command_buffer_completed_handler_block_descriptor_1 *descriptor;
void *const user_context;
};

WEAK command_buffer_completed_handler_block_descriptor_1 command_buffer_completed_handler_descriptor = {
0, sizeof(command_buffer_completed_handler_block_literal)};

WEAK void command_buffer_completed_handler_invoke(command_buffer_completed_handler_block_literal *block, mtl_command_buffer *buffer) {
objc_id buffer_error = command_buffer_error(buffer);
if (buffer_error != nullptr) {
ns_log_object(buffer_error);
release_ns_object(buffer_error);
}
}
retain_ns_object(buffer);
char *error_string = nullptr;
void *const user_context = block->user_context;

WEAK command_buffer_completed_handler_block_literal command_buffer_completed_handler_block = {
&_NSConcreteGlobalBlock,
(1 << 28) | (1 << 29), // BLOCK_IS_GLOBAL | BLOCK_HAS_DESCRIPTOR
0, command_buffer_completed_handler_invoke,
&command_buffer_completed_handler_descriptor};
auto status = halide_metal_command_buffer_completion_handler(user_context, buffer, &error_string);
release_ns_object(buffer);

MetalContextHolder::set_saved_status(status, error_string);
free(error_string);
}

} // namespace Metal
} // namespace Internal
Expand Down Expand Up @@ -476,7 +598,7 @@ WEAK int halide_metal_device_malloc(void *user_context, halide_buffer_t *buf) {

MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

#ifdef DEBUG_RUNTIME
Expand Down Expand Up @@ -544,7 +666,7 @@ WEAK int halide_metal_device_free(void *user_context, halide_buffer_t *buf) {
WEAK int halide_metal_initialize_kernels(void *user_context, void **state_ptr, const char *source, int source_size) {
MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}
#ifdef DEBUG_RUNTIME
uint64_t t_before = halide_current_time_ns(user_context);
Expand Down Expand Up @@ -600,7 +722,7 @@ WEAK int halide_metal_device_sync(void *user_context, struct halide_buffer_t *bu

MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

halide_metal_device_sync_internal(metal_context.queue, buffer);
Expand Down Expand Up @@ -651,7 +773,7 @@ WEAK int halide_metal_copy_to_device(void *user_context, halide_buffer_t *buffer

MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

if (!(buffer->host && buffer->device)) {
Expand Down Expand Up @@ -695,7 +817,7 @@ WEAK int halide_metal_copy_to_host(void *user_context, halide_buffer_t *buffer)

MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

halide_metal_device_sync_internal(metal_context.queue, buffer);
Expand Down Expand Up @@ -738,7 +860,7 @@ WEAK int halide_metal_run(void *user_context,

MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

mtl_command_buffer *command_buffer = new_command_buffer(metal_context.queue, entry_name, strlen(entry_name));
Expand Down Expand Up @@ -882,6 +1004,13 @@ WEAK int halide_metal_run(void *user_context,
threadsX, threadsY, threadsZ);
end_encoding(encoder);

command_buffer_completed_handler_block_literal command_buffer_completed_handler_block = {
&_NSConcreteStackBlock,
0, // must be 0 for stack blocks
0, command_buffer_completed_handler_invoke,
&command_buffer_completed_handler_descriptor,
user_context};

add_command_buffer_completed_handler(command_buffer, &command_buffer_completed_handler_block);

commit_command_buffer(command_buffer);
Expand Down Expand Up @@ -962,7 +1091,7 @@ WEAK int halide_metal_buffer_copy(void *user_context, struct halide_buffer_t *sr
{
MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

debug(user_context)
Expand Down Expand Up @@ -1036,7 +1165,7 @@ WEAK int metal_device_crop_from_offset(void *user_context,
struct halide_buffer_t *dst) {
MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

dst->device_interface = src->device_interface;
Expand Down
3 changes: 2 additions & 1 deletion test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,13 @@ tests(GROUPS correctness
gpu_data_flows.cpp
gpu_different_blocks_threads_dimensions.cpp
gpu_dynamic_shared.cpp
gpu_f16_intrinsics.cpp
gpu_free_sync.cpp
gpu_give_input_buffers_device_allocations.cpp
gpu_jit_explicit_copy_to_device.cpp
gpu_large_alloc.cpp
gpu_many_kernels.cpp
gpu_f16_intrinsics.cpp
gpu_metal_completion_handler_error_check.cpp
gpu_mixed_dimensionality.cpp
gpu_mixed_shared_mem_types.cpp
gpu_multi_kernel.cpp
Expand Down
45 changes: 45 additions & 0 deletions test/correctness/gpu_metal_completion_handler_error_check.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

bool errored = false;

int main(int argc, char **argv) {
Target t = get_jit_target_from_environment();
if (!t.has_feature(Target::Metal)) {
printf("[SKIP] Metal not enabled\n");
return 0;
}

Func f, g;
Var c, x, ci, xi;
RVar rxi;
RDom r(0, 1000, -327600, 327600);

// Create a function that is very costly to execute, resulting in a timeout
// on the GPU
f(x, c) = x + 0.1f * c;
f(r.x, c) += cos(sin(tan(cosh(tanh(sinh(exp(tanh(exp(log(tan(cos(exp(f(r.x, c) / cos(cosh(sinh(sin(f(r.x, c))))) / tanh(tan(tan(f(r.x, c)))))))))) + cast<float>(cast<uint8_t>(f(r.x, c) / cast<uint8_t>(log(f(r.x, c))))))))))));

f.gpu_tile(x, c, xi, ci, 4, 4);
f.update(0).gpu_tile(r.x, c, rxi, ci, 4, 4);

// Metal is surprisingly resilient. Run this in a loop just to make sure we trigger the error.
for (int i = 0; (i < 10) && !errored; i++) {
auto out = f.realize({1000, 100}, t);
int result = out.device_sync();
if (result != halide_error_code_success) {
printf("Device sync failed as expected: %d\n", result);
errored = true;
}
}

if (!errored) {
printf("There was supposed to be an error\n");
return 1;
}

printf("Success!\n");
return 0;
}
5 changes: 5 additions & 0 deletions test/generator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,11 @@ _add_halide_libraries(metadata_tester_ucon
_add_halide_aot_tests(metadata_tester
HALIDE_LIBRARIES metadata_tester metadata_tester_ucon)

# metal_completion_handler_override_aottest.cpp
# metal_completion_handler_override_generator.cpp
_add_halide_libraries(metal_completion_handler_override FEATURES user_context)
_add_halide_aot_tests(metal_completion_handler_override)

# msan_aottest.cpp
# msan_generator.cpp
if ("${Halide_TARGET}" MATCHES "webgpu")
Expand Down
Loading

0 comments on commit f9ccd5c

Please sign in to comment.