Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No longer silently hide errors in Metal completion handlers (alternative approach) #8240

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 76 additions & 11 deletions src/runtime/metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,8 @@ WEAK int halide_metal_release_context(void *user_context) {

} // extern "C"

// Declaration of the C library's bounded string-length routine; the
// MetalContextHolder status accessors below call it to test whether a message
// has been stored.
// NOTE(review): presumably declared by hand because this runtime file does not
// include <string.h> -- confirm.
extern "C" size_t strnlen(const char *s, size_t maxlen);

namespace Halide {
namespace Runtime {
namespace Internal {
Expand All @@ -389,7 +391,10 @@ namespace Metal {
class MetalContextHolder {
objc_id pool;
void *const user_context;
int status; // must always be a valid halide_error_code_t value
int status; // must always be a valid halide_error_code_t value
static int saved_status; // must always be a valid halide_error_code_t value
static halide_mutex saved_status_mutex; // mutex for accessing saved status
static char error_string[1024];

public:
mtl_device *device;
Expand All @@ -404,11 +409,55 @@ class MetalContextHolder {
drain_autorelease_pool(pool);
}

ALWAYS_INLINE int error() const {
return status;
// We use two variants of this function: one for just checking status, and one
// that returns and clears the previous status.
//
// Returns the saved status and resets it to halide_error_code_success, all
// while holding saved_status_mutex. If `error_string` is non-null, the returned
// status is an error, and a message is stored, the message is copied into
// `error_string` (always NUL-terminated) and logged via debug().
//
// The stored message is cleared on every call, so a message belonging to a
// status that has already been consumed can never be attached to a later error.
//
// `error_string` is assumed to point at a buffer of at least 1024 bytes -- the
// callers are not visible here; confirm before passing smaller buffers.
ALWAYS_INLINE static int get_and_clear_saved_status(char *error_string = nullptr) {
    halide_mutex_lock(&saved_status_mutex);
    const int result = saved_status;
    saved_status = halide_error_code_success;
    if (error_string != nullptr && result != halide_error_code_success && strnlen(MetalContextHolder::error_string, 1024) > 0) {
        // strncpy does not terminate the destination when the source fills the
        // whole count, so copy one byte less and terminate explicitly.
        strncpy(error_string, MetalContextHolder::error_string, 1023);
        error_string[1023] = '\0';
        debug(nullptr) << "MetalContextHolder::get_and_clear_saved_status: " << error_string << "\n";
    }
    // Drop the message together with the status, whether or not it was copied out.
    MetalContextHolder::error_string[0] = '\0';
    halide_mutex_unlock(&saved_status_mutex);
    return result;
}

// Returns the saved status without clearing it, reading it while holding
// saved_status_mutex. If `error_string` is non-null, the status is an error,
// and a message is stored, the message is copied into `error_string` and is
// always NUL-terminated. Otherwise `error_string` is left untouched.
//
// `error_string` is assumed to point at a buffer of at least 1024 bytes -- the
// callers are not visible here; confirm before passing smaller buffers.
ALWAYS_INLINE static int get_saved_status(char *error_string = nullptr) {
    halide_mutex_lock(&saved_status_mutex);
    const int result = saved_status;
    if (error_string != nullptr && result != halide_error_code_success && strnlen(MetalContextHolder::error_string, 1024) > 0) {
        // strncpy does not terminate the destination when the source fills the
        // whole count, so copy one byte less and terminate explicitly.
        strncpy(error_string, MetalContextHolder::error_string, 1023);
        error_string[1023] = '\0';
    }
    halide_mutex_unlock(&saved_status_mutex);
    return result;
}

// Records `new_status` as the saved status while holding saved_status_mutex.
//
// If `error_string` is non-null, it is copied into the 1024-byte static message
// buffer, truncated to 1023 characters so the buffer is always NUL-terminated,
// and logged via debug(). If it is null, the stored message is cleared so that
// a message from an earlier error is not reported with the new status.
ALWAYS_INLINE static void set_saved_status(int new_status, char *error_string = nullptr) {
    halide_mutex_lock(&saved_status_mutex);
    saved_status = new_status;
    if (error_string != nullptr) {
        // strncpy leaves the destination unterminated when the source is at
        // least as long as the count; reserve the last byte for the terminator.
        strncpy(MetalContextHolder::error_string, error_string, 1023);
        MetalContextHolder::error_string[1023] = '\0';
        debug(nullptr) << "MetalContextHolder::set_saved_status: " << error_string << "\n";
    } else {
        MetalContextHolder::error_string[0] = '\0';
    }
    halide_mutex_unlock(&saved_status_mutex);
}

// Returns this holder's own status if it is an error; otherwise returns the
// saved status (without clearing it), optionally copying the saved message into
// `error_string` via get_saved_status().
//
// The result is the actual error code rather than a 0/1 truth value: combining
// the two codes with `||` would collapse any failure to 1, which is not the
// code that was recorded. Callers that only test the result for non-zero are
// unaffected.
ALWAYS_INLINE int error(char *error_string = nullptr) const {
    if (status != halide_error_code_success) {
        return status;
    }
    return get_saved_status(error_string);
}

// Returns this holder's own status if it is an error, leaving the saved status
// untouched in that case; otherwise returns and clears the saved status,
// optionally copying the saved message into `error_string` via
// get_and_clear_saved_status().
//
// The result is the actual error code rather than a 0/1 truth value: combining
// the two codes with `||` would collapse any failure to 1, which is not the
// code that was recorded.
ALWAYS_INLINE int get_and_clear_error(char *error_string = nullptr) const {
    if (status != halide_error_code_success) {
        return status;
    }
    return get_and_clear_saved_status(error_string);
}
};

// Out-of-class definitions of MetalContextHolder's static members.
// saved_status starts out as "no error"; it is overwritten through
// set_saved_status() and read back through the get_*saved_status() accessors.
int MetalContextHolder::saved_status = halide_error_code_success;
// Guards saved_status and error_string; every accessor locks it before touching either.
halide_mutex MetalContextHolder::saved_status_mutex = {0};
// Message accompanying saved_status; zero-initialized, i.e. the empty string.
char MetalContextHolder::error_string[1024] = {0};

struct command_buffer_completed_handler_block_descriptor_1 {
unsigned long reserved;
unsigned long block_size;
Expand All @@ -428,7 +477,23 @@ WEAK command_buffer_completed_handler_block_descriptor_1 command_buffer_complete
// Completion handler body for a Metal command buffer. If the buffer reports an
// error object, the error is logged and recorded (with its localized
// description) through MetalContextHolder::set_saved_status so that it can be
// reported on the next check for an error. `block` is not used.
WEAK void command_buffer_completed_handler_invoke(command_buffer_completed_handler_block_literal *block, mtl_command_buffer *buffer) {
    objc_id buffer_error = command_buffer_error(buffer);
    if (buffer_error == nullptr) {
        // Nothing went wrong; there is nothing to record.
        return;
    }

    // Hold on to the error object for as long as we read from it.
    retain_ns_object(buffer_error);

    // Ask the error object for its localized description (an NSString).
    using description_getter_t = objc_id (*)(objc_id objc, objc_sel sel);
    description_getter_t get_description = reinterpret_cast<description_getter_t>(&objc_msgSend);
    objc_id description = get_description(buffer_error, sel_getUid("localizedDescription"));

    // Ask that NSString for a C string. The NSString is not released here, so
    // the C string is still usable while the error is printed and recorded.
    using c_string_getter_t = char *(*)(objc_id objc, objc_sel sel);
    c_string_getter_t get_c_string = reinterpret_cast<c_string_getter_t>(&objc_msgSend);
    char *message = get_c_string(description, sel_getUid("UTF8String"));

    ns_log_object(buffer_error);

    // The command buffer was not executed. Because this handler runs
    // asynchronously, the failure is parked in static storage and surfaced by
    // the next error check.
    MetalContextHolder::set_saved_status(halide_error_code_device_run_failed, message);

    release_ns_object(buffer_error);
}
Expand Down Expand Up @@ -476,7 +541,7 @@ WEAK int halide_metal_device_malloc(void *user_context, halide_buffer_t *buf) {

MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

#ifdef DEBUG_RUNTIME
Expand Down Expand Up @@ -544,7 +609,7 @@ WEAK int halide_metal_device_free(void *user_context, halide_buffer_t *buf) {
WEAK int halide_metal_initialize_kernels(void *user_context, void **state_ptr, const char *source, int source_size) {
MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}
#ifdef DEBUG_RUNTIME
uint64_t t_before = halide_current_time_ns(user_context);
Expand Down Expand Up @@ -600,7 +665,7 @@ WEAK int halide_metal_device_sync(void *user_context, struct halide_buffer_t *bu

MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

halide_metal_device_sync_internal(metal_context.queue, buffer);
Expand Down Expand Up @@ -651,7 +716,7 @@ WEAK int halide_metal_copy_to_device(void *user_context, halide_buffer_t *buffer

MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

if (!(buffer->host && buffer->device)) {
Expand Down Expand Up @@ -695,7 +760,7 @@ WEAK int halide_metal_copy_to_host(void *user_context, halide_buffer_t *buffer)

MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

halide_metal_device_sync_internal(metal_context.queue, buffer);
Expand Down Expand Up @@ -738,7 +803,7 @@ WEAK int halide_metal_run(void *user_context,

MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

mtl_command_buffer *command_buffer = new_command_buffer(metal_context.queue, entry_name, strlen(entry_name));
Expand Down Expand Up @@ -962,7 +1027,7 @@ WEAK int halide_metal_buffer_copy(void *user_context, struct halide_buffer_t *sr
{
MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

debug(user_context)
Expand Down Expand Up @@ -1036,7 +1101,7 @@ WEAK int metal_device_crop_from_offset(void *user_context,
struct halide_buffer_t *dst) {
MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

dst->device_interface = src->device_interface;
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ tests(GROUPS correctness
gpu_jit_explicit_copy_to_device.cpp
gpu_large_alloc.cpp
gpu_many_kernels.cpp
gpu_metal_completion_handler_error_check.cpp
gpu_mixed_dimensionality.cpp
gpu_mixed_shared_mem_types.cpp
gpu_multi_kernel.cpp
Expand Down
59 changes: 59 additions & 0 deletions test/correctness/gpu_metal_completion_handler_error_check.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#include "Halide.h"
#include <unistd.h>

using namespace Halide;

bool errored = false;
// Custom JIT error handler: prints the message and records that an error was
// seen by setting the global `errored` flag. The user context is ignored.
void my_error(JITUserContext * /*unused*/, const char *msg) {
    // The label is deliberately "err", not the full word: on Windows, CMake
    // reports the test as failed whenever output matching "error.*:" shows up
    // on stdout or stderr, regardless of the error code returned.
    printf("Expected err: %s\n", msg);
    errored = true;
}

int main(int argc, char **argv) {
Target t = get_jit_target_from_environment();
if (!t.has_feature(Target::Metal)) {
printf("[SKIP] Metal not enabled\n");
return 0;
}

Func f, g;
Var c, x, ci, xi;
RVar rxi;
RDom r(0, 1000, -327600, 327600);

// Create a function that is very costly to execute, resulting in a timeout
// on the GPU
f(x, c) = x + 0.1f * c;
f(r.x, c) += cos(sin(tan(cosh(tanh(sinh(exp(tanh(exp(log(tan(cos(exp(f(r.x, c) / cos(cosh(sinh(sin(f(r.x, c))))) / tanh(tan(tan(f(r.x, c)))))))))) + cast<float>(cast<uint8_t>(f(r.x, c) / cast<uint8_t>(log(f(r.x, c))))))))))));

f.gpu_tile(x, c, xi, ci, 4, 4);
f.update(0).gpu_tile(r.x, c, rxi, ci, 4, 4);

// Because the error handler is invoked from a Metal runtime thread, setting a custom handler for just
// this pipeline is insufficient. Instead, we set a custom handler for the JIT runtime
JITHandlers handlers;
handlers.custom_error = my_error;
Internal::JITSharedRuntime::set_default_handlers(handlers);
f.jit_handlers().custom_error = my_error;

// Metal is surprisingly resilient. Run this in a loop just to make sure we trigger the error.
for (int i = 0; (i < 10) && !errored; i++) {
auto out = f.realize({1000, 100}, t);
int result = out.device_sync();
if (result != halide_error_code_success) {
printf("Device sync failed as expected: %d\n", result);
errored = true;
}
}

if (!errored) {
printf("There was supposed to be an error\n");
return 1;
}

printf("Success!\n");
return 0;
}
Loading