Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No longer silently hide errors in Metal completion handlers (alternative approach) #8240

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/runtime/metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,9 +426,26 @@ WEAK command_buffer_completed_handler_block_descriptor_1 command_buffer_complete
0, sizeof(command_buffer_completed_handler_block_literal)};

// Invoke function of the completion-handler block attached to Metal command
// buffers. Queries the completed command buffer for an error and, if one is
// present, logs it and reports its localized description through
// halide_error() so that the failure is no longer silently dropped.
//
// block  - the block literal this invoke function belongs to (unused here).
// buffer - the command buffer that has completed.
WEAK void command_buffer_completed_handler_invoke(command_buffer_completed_handler_block_literal *block, mtl_command_buffer *buffer) {
    // NOTE(review): this trace is emitted on every completion, including
    // successful ones -- confirm whether it should be gated behind a debug flag.
    halide_print(nullptr, "Completion handler invoked\n");

    objc_id buffer_error = command_buffer_error(buffer);
    if (buffer_error == nullptr) {
        // No error recorded on the command buffer: nothing to report.
        return;
    }

    // Keep the error object alive while its description is being read.
    retain_ns_object(buffer_error);

    // Obtain the localized NSString for the error
    typedef objc_id (*localized_description_method_t)(objc_id objc, objc_sel sel);
    localized_description_method_t localized_description_method = (localized_description_method_t)&objc_msgSend;
    objc_id error_ns_string = (*localized_description_method)(buffer_error, sel_getUid("localizedDescription"));

    // Obtain a C-style string, but do not release the NSString until reporting/printing the error
    typedef const char *(*utf8_string_method_t)(objc_id objc, objc_sel sel);
    utf8_string_method_t utf8_string_method = (utf8_string_method_t)&objc_msgSend;
    const char *error_string = (*utf8_string_method)(error_ns_string, sel_getUid("UTF8String"));

    // Messaging a nil NSString yields a null pointer; never hand a null
    // message to the error handler, which may format it as a C string.
    if (error_string == nullptr) {
        error_string = "Metal command buffer completed with an error (no description available)";
    }

    ns_log_object(buffer_error);

    // This is an error indicating the command buffer wasn't executed, and because it is asynchronous
    // with respect to the pipeline that caused it, it is not recoverable
    halide_error(nullptr, error_string);
    release_ns_object(buffer_error);
}
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ tests(GROUPS correctness
gpu_jit_explicit_copy_to_device.cpp
gpu_large_alloc.cpp
gpu_many_kernels.cpp
gpu_metal_completion_handler_error_check.cpp
gpu_mixed_dimensionality.cpp
gpu_mixed_shared_mem_types.cpp
gpu_multi_kernel.cpp
Expand Down
51 changes: 51 additions & 0 deletions test/correctness/gpu_metal_completion_handler_error_check.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#include "Halide.h"

#include <atomic>
#include <cstdio>
#include <unistd.h>

using namespace Halide;

// Set once the custom error handler has run. The handler is invoked from a
// Metal runtime thread while main() reads the flag, so it is atomic rather
// than a plain bool to avoid an unsynchronized cross-thread access.
std::atomic<bool> errored{false};

// Custom JIT error handler: prints the message and records that an error
// was reported.
//
// msg - the error text supplied by the runtime.
void my_error(JITUserContext *, const char *msg) {
    // Emitting "error.*:" to stdout or stderr will cause CMake to report the
    // test as a failure on Windows, regardless of error code returned,
    // hence the abbreviation to "err".
    printf("Expected err: %s\n", msg);
    errored.store(true);
}

int main(int argc, char **argv) {
Target t = get_jit_target_from_environment();
if (!t.has_feature(Target::Metal)) {
printf("[SKIP] Metal not enabled\n");
return 0;
}

Func f;
Var c, x, ci, xi;
RVar rxi;
RDom r(0, 1000, -327600, 327600);

// Create a function that is very costly to execute, resulting in a timeout
// on the GPU
f(x, c) = x + 0.1f * c;
f(r.x, c) += cos(sin(tan(cosh(tanh(sinh(exp(tanh(exp(log(tan(cos(exp(f(r.x, c) / cos(cosh(sinh(sin(f(r.x, c)))))
/ tanh(tan(tan(f(r.x, c)))))))))) + cast<float>(cast<uint8_t>(f(r.x, c) / cast<uint8_t>(log(f(r.x, c))))))))))));

f.gpu_tile(x, c, xi, ci, 4, 4);
f.update(0).gpu_tile(r.x, c, rxi, ci, 4, 4);

// Because the error handler is invoked from a Metal runtime thread, setting a custom handler for just
// this pipeline is insufficient. Instead, we set a custom handler for the JIT runtime
JITHandlers handlers;
handlers.custom_error = my_error;
Internal::JITSharedRuntime::set_default_handlers(handlers);

f.realize({1000, 100}, t);
shoaibkamil marked this conversation as resolved.
Show resolved Hide resolved

if (!errored) {
printf("There was supposed to be an error\n");
return 1;
}

printf("Success!\n");
return 0;
}
Loading