Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No longer silently hide errors in Metal completion handlers (alternative approach) #8240

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/runtime/HalideRuntimeMetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ extern uint64_t halide_metal_get_crop_offset(void *user_context, struct halide_b

struct halide_metal_device;
struct halide_metal_command_queue;
struct halide_metal_command_buffer;

/** This prototype is exported as applications will typically need to
* replace it to get Halide filters to execute on the same device and
Expand All @@ -93,6 +94,17 @@ extern int halide_metal_acquire_context(void *user_context, struct halide_metal_
*/
extern int halide_metal_release_context(void *user_context);

/** Hook invoked from the completion callback that the Metal runtime attaches
 * to a command buffer; it inspects the finished buffer and reports whether it
 * failed.
 *
 * \param user_context          The user_context that was passed to the runtime
 *                              call which committed the command buffer.
 * \param buffer                The command buffer that has completed.
 * \param returned_error_string Out-parameter for a human-readable error
 *                              message. An implementation that sets
 *                              *returned_error_string must allocate the string
 *                              with malloc(), because the runtime releases it
 *                              with free() once it has taken its own copy.
 *                              The runtime keeps at most 1023 characters of it.
 * \return A valid Halide error code. The runtime stashes the returned code
 *         (together with the error string, if any) and hands it back from the
 *         next call into the Metal runtime.
 *
 * This is called from the Metal driver, and thus:
 * - Any user_context must be preserved between the call to halide_metal_run
 *   and the corresponding callback
 * - The function must be thread-safe
 */
extern int halide_metal_command_buffer_completion_handler(void *user_context, struct halide_metal_command_buffer *buffer,
char **returned_error_string);

#ifdef __cplusplus
} // End extern "C"
#endif
Expand Down
175 changes: 152 additions & 23 deletions src/runtime/metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
extern "C" {
extern objc_id MTLCreateSystemDefaultDevice();
extern struct ObjectiveCClass _NSConcreteGlobalBlock;
extern struct ObjectiveCClass _NSConcreteStackBlock;
void *dlsym(void *, const char *);
#define RTLD_DEFAULT ((void *)-2)
}
Expand All @@ -23,8 +24,8 @@ namespace Metal {

typedef halide_metal_device mtl_device;
typedef halide_metal_command_queue mtl_command_queue;
typedef halide_metal_command_buffer mtl_command_buffer;
struct mtl_buffer;
struct mtl_command_buffer;
struct mtl_compute_command_encoder;
struct mtl_blit_command_encoder;
struct mtl_compute_pipeline_state;
Expand Down Expand Up @@ -381,6 +382,8 @@ WEAK int halide_metal_release_context(void *user_context) {

} // extern "C"

extern "C" size_t strnlen(const char *s, size_t maxlen);

namespace Halide {
namespace Runtime {
namespace Internal {
Expand All @@ -389,7 +392,10 @@ namespace Metal {
class MetalContextHolder {
objc_id pool;
void *const user_context;
int status; // must always be a valid halide_error_code_t value
int status; // must always be a valid halide_error_code_t value
static int saved_status; // must always be a valid halide_error_code_t value
static halide_mutex saved_status_mutex; // mutex for accessing saved status
static char error_string[1024];

public:
mtl_device *device;
Expand All @@ -404,11 +410,128 @@ class MetalContextHolder {
drain_autorelease_pool(pool);
}

ALWAYS_INLINE int error() const {
return status;
// Saved-status accessors come in two flavours: a peeking one, and this one,
// which hands back the stashed status and resets it to success.
//
// If `error_string` is non-null, the stashed status is a failure, and a
// message is stored, the message (at most 1023 characters plus terminator) is
// copied into `error_string` and the stored message is emptied. The caller's
// buffer must therefore hold at least 1024 bytes.
ALWAYS_INLINE static int get_and_clear_saved_status(char *error_string = nullptr) {
    halide_mutex_lock(&saved_status_mutex);

    const int previous = saved_status;
    saved_status = halide_error_code_success;

    const bool caller_wants_message = (error_string != nullptr);
    const bool previous_failed = (previous != halide_error_code_success);
    if (caller_wants_message && previous_failed && MetalContextHolder::error_string[0] != '\0') {
        strncpy(error_string, MetalContextHolder::error_string, 1024);
        // strncpy does not terminate when the source fills the whole buffer.
        error_string[1023] = '\0';
        // The message has been handed over; drop the stored copy.
        MetalContextHolder::error_string[0] = '\0';
        debug(nullptr) << "MetalContextHolder::get_and_clear_saved_status: " << error_string << "\n";
    }

    halide_mutex_unlock(&saved_status_mutex);
    return previous;
}

// Peeks at the stashed status without resetting it. When `error_string` is
// non-null, the status is a failure, and a message is stored, the message is
// copied out (at most 1023 characters plus terminator), so the caller's
// buffer must hold at least 1024 bytes. The stored message is left in place.
ALWAYS_INLINE static int get_saved_status(char *error_string = nullptr) {
    halide_mutex_lock(&saved_status_mutex);

    const int current = saved_status;
    const bool caller_wants_message = (error_string != nullptr);
    const bool is_failure = (current != halide_error_code_success);
    if (caller_wants_message && is_failure && MetalContextHolder::error_string[0] != '\0') {
        strncpy(error_string, MetalContextHolder::error_string, 1024);
        // strncpy does not terminate when the source fills the whole buffer.
        error_string[1023] = '\0';
    }

    halide_mutex_unlock(&saved_status_mutex);
    return current;
}

// Stashes `new_status` as the saved status and, when `error_string` is
// non-null, stores a copy of the message (truncated to 1023 characters).
// The caller keeps ownership of `error_string`; it is only read here.
ALWAYS_INLINE static void set_saved_status(int new_status, char *error_string = nullptr) {
    halide_mutex_lock(&saved_status_mutex);
    saved_status = new_status;
    if (error_string != nullptr) {
        strncpy(MetalContextHolder::error_string, error_string, 1024);
        // strncpy leaves the destination unterminated when the source is 1024
        // characters or longer, so terminate the *stored* copy. Writing the
        // terminator into the caller's buffer instead would leave the stored
        // message unterminated and could write past the end of a shorter
        // caller allocation.
        MetalContextHolder::error_string[1023] = '\0';
        debug(nullptr) << "MetalContextHolder::set_saved_status: " << error_string << "\n";
    }
    halide_mutex_unlock(&saved_status_mutex);
}

// Reports the error state without clearing anything: this holder's own status
// takes precedence when it is a failure; otherwise the stashed status is
// returned (and its message copied into `error_string`, if one is supplied).
ALWAYS_INLINE int error(char *error_string = nullptr) const {
    const bool own_failure = (status != halide_error_code_success);
    return own_failure ? status : get_saved_status(error_string);
}

// Like error(), but always resets the stashed status first. The stashed
// status is consumed even when this holder's own failure ends up being the
// value returned.
ALWAYS_INLINE int get_and_clear_error(char *error_string = nullptr) const {
    const int stashed = get_and_clear_saved_status(error_string);
    const bool own_failure = (status != halide_error_code_success);
    return own_failure ? status : stashed;
}
};

// Out-of-class definitions of MetalContextHolder's static members.
// The stashed status starts out as "no error".
int MetalContextHolder::saved_status = halide_error_code_success;
// Guards saved_status and error_string; the accessors lock it around every access.
halide_mutex MetalContextHolder::saved_status_mutex = {0};
// Message that accompanies saved_status; starts out as the empty string.
char MetalContextHolder::error_string[1024] = {0};

} // namespace Metal
} // namespace Internal
} // namespace Runtime
} // namespace Halide

extern "C" {
/** This function is called as part of the callback when a Metal command buffer completes.
 * The return value, if not halide_error_code_success, will be stashed in Metal runtime and returned
 * to the next call into the runtime, and the error string will be saved as well.
 *
 * If the buffer reports an error and returned_error_string is non-null, a
 * malloc()'d copy of the error's localized description (truncated to 1023
 * characters) is stored in *returned_error_string; the caller releases it
 * with free(). *returned_error_string is left untouched when the buffer has
 * no error or when no description text is available, and is set to nullptr
 * when the allocation fails.
 *
 * The return value must be a valid Halide error code.
 * This is called from the Metal driver, and thus:
 * - Any user_context must be preserved between the call to halide_metal_run and the corresponding callback
 * - The function must be thread-safe
 */
WEAK int halide_metal_command_buffer_completion_handler(void *const user_context, mtl_command_buffer *buffer, char **returned_error_string) {
    objc_id buffer_error = command_buffer_error(buffer);
    if (buffer_error == nullptr) {
        return halide_error_code_success;
    }

    // Hold on to the error object while its description is extracted.
    retain_ns_object(buffer_error);

    ns_log_object(buffer_error);

    // Obtain the localized NSString for the error
    typedef objc_id (*localized_description_method_t)(objc_id objc, objc_sel sel);
    localized_description_method_t localized_description_method = (localized_description_method_t)&objc_msgSend;
    objc_id error_ns_string = (*localized_description_method)(buffer_error, sel_getUid("localizedDescription"));

    retain_ns_object(error_ns_string);

    // Obtain a C-style string
    typedef char *(*utf8_string_method_t)(objc_id objc, objc_sel sel);
    utf8_string_method_t utf8_string_method = (utf8_string_method_t)&objc_msgSend;
    char *error_string = (*utf8_string_method)(error_ns_string, sel_getUid("UTF8String"));

    // Copy C-style string into a fresh buffer
    if (returned_error_string != nullptr) {
        if (error_string == nullptr) {
            // A nil description makes the UTF8String message send return
            // null; passing that to strncpy would be undefined behaviour, so
            // report the error code without a message instead.
            debug(user_context) << "halide_metal_command_buffer_completion_handler: No error description available.\n";
        } else {
            *returned_error_string = (char *)malloc(sizeof(char) * 1024);
            if (*returned_error_string != nullptr) {
                strncpy(*returned_error_string, error_string, 1024);
                // Ensure null-termination, since strncpy won't if the source string is too long
                (*returned_error_string)[1023] = '\0';
            } else {
                debug(user_context) << "halide_metal_command_buffer_completion_handler: Failed to allocate memory for error string.\n";
            }
        }
    }

    release_ns_object(error_ns_string);
    release_ns_object(buffer_error);
    return halide_error_code_device_run_failed;
}
} // extern "C"

namespace Halide {
namespace Runtime {
namespace Internal {
namespace Metal {

struct command_buffer_completed_handler_block_descriptor_1 {
unsigned long reserved;
unsigned long block_size;
Expand All @@ -420,24 +543,23 @@ struct command_buffer_completed_handler_block_literal {
int reserved;
void (*invoke)(command_buffer_completed_handler_block_literal *, mtl_command_buffer *buffer);
struct command_buffer_completed_handler_block_descriptor_1 *descriptor;
void *const user_context;
};

// Descriptor referenced by the completion-handler block literal: `reserved`
// is 0 and `block_size` is the size of the literal struct.
WEAK command_buffer_completed_handler_block_descriptor_1 command_buffer_completed_handler_descriptor = {
0, sizeof(command_buffer_completed_handler_block_literal)};

WEAK void command_buffer_completed_handler_invoke(command_buffer_completed_handler_block_literal *block, mtl_command_buffer *buffer) {
shoaibkamil marked this conversation as resolved.
Show resolved Hide resolved
objc_id buffer_error = command_buffer_error(buffer);
if (buffer_error != nullptr) {
ns_log_object(buffer_error);
release_ns_object(buffer_error);
}
}
retain_ns_object(buffer);
char *error_string = nullptr;
void *const user_context = block->user_context;

WEAK command_buffer_completed_handler_block_literal command_buffer_completed_handler_block = {
&_NSConcreteGlobalBlock,
(1 << 28) | (1 << 29), // BLOCK_IS_GLOBAL | BLOCK_HAS_DESCRIPTOR
0, command_buffer_completed_handler_invoke,
&command_buffer_completed_handler_descriptor};
auto status = halide_metal_command_buffer_completion_handler(user_context, buffer, &error_string);
release_ns_object(buffer);

MetalContextHolder::set_saved_status(status, error_string);
free(error_string);
}

} // namespace Metal
} // namespace Internal
Expand Down Expand Up @@ -476,7 +598,7 @@ WEAK int halide_metal_device_malloc(void *user_context, halide_buffer_t *buf) {

MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

#ifdef DEBUG_RUNTIME
Expand Down Expand Up @@ -544,7 +666,7 @@ WEAK int halide_metal_device_free(void *user_context, halide_buffer_t *buf) {
WEAK int halide_metal_initialize_kernels(void *user_context, void **state_ptr, const char *source, int source_size) {
MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}
#ifdef DEBUG_RUNTIME
uint64_t t_before = halide_current_time_ns(user_context);
Expand Down Expand Up @@ -600,7 +722,7 @@ WEAK int halide_metal_device_sync(void *user_context, struct halide_buffer_t *bu

MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

halide_metal_device_sync_internal(metal_context.queue, buffer);
Expand Down Expand Up @@ -651,7 +773,7 @@ WEAK int halide_metal_copy_to_device(void *user_context, halide_buffer_t *buffer

MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

if (!(buffer->host && buffer->device)) {
Expand Down Expand Up @@ -695,7 +817,7 @@ WEAK int halide_metal_copy_to_host(void *user_context, halide_buffer_t *buffer)

MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

halide_metal_device_sync_internal(metal_context.queue, buffer);
Expand Down Expand Up @@ -738,7 +860,7 @@ WEAK int halide_metal_run(void *user_context,

MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

mtl_command_buffer *command_buffer = new_command_buffer(metal_context.queue, entry_name, strlen(entry_name));
Expand Down Expand Up @@ -882,6 +1004,13 @@ WEAK int halide_metal_run(void *user_context,
threadsX, threadsY, threadsZ);
end_encoding(encoder);

command_buffer_completed_handler_block_literal command_buffer_completed_handler_block = {
&_NSConcreteStackBlock,
0, // must be 0 for stack blocks
0, command_buffer_completed_handler_invoke,
&command_buffer_completed_handler_descriptor,
user_context};

add_command_buffer_completed_handler(command_buffer, &command_buffer_completed_handler_block);

commit_command_buffer(command_buffer);
Expand Down Expand Up @@ -962,7 +1091,7 @@ WEAK int halide_metal_buffer_copy(void *user_context, struct halide_buffer_t *sr
{
MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

debug(user_context)
Expand Down Expand Up @@ -1036,7 +1165,7 @@ WEAK int metal_device_crop_from_offset(void *user_context,
struct halide_buffer_t *dst) {
MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
return metal_context.error();
return metal_context.get_and_clear_error();
}

dst->device_interface = src->device_interface;
Expand Down
3 changes: 2 additions & 1 deletion test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,13 @@ tests(GROUPS correctness
gpu_data_flows.cpp
gpu_different_blocks_threads_dimensions.cpp
gpu_dynamic_shared.cpp
gpu_f16_intrinsics.cpp
gpu_free_sync.cpp
gpu_give_input_buffers_device_allocations.cpp
gpu_jit_explicit_copy_to_device.cpp
gpu_large_alloc.cpp
gpu_many_kernels.cpp
gpu_f16_intrinsics.cpp
gpu_metal_completion_handler_error_check.cpp
gpu_mixed_dimensionality.cpp
gpu_mixed_shared_mem_types.cpp
gpu_multi_kernel.cpp
Expand Down
45 changes: 45 additions & 0 deletions test/correctness/gpu_metal_completion_handler_error_check.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

bool errored = false;

int main(int argc, char **argv) {
Target t = get_jit_target_from_environment();
if (!t.has_feature(Target::Metal)) {
printf("[SKIP] Metal not enabled\n");
return 0;
}

Func f, g;
Var c, x, ci, xi;
RVar rxi;
RDom r(0, 1000, -327600, 327600);

// Create a function that is very costly to execute, resulting in a timeout
// on the GPU
f(x, c) = x + 0.1f * c;
f(r.x, c) += cos(sin(tan(cosh(tanh(sinh(exp(tanh(exp(log(tan(cos(exp(f(r.x, c) / cos(cosh(sinh(sin(f(r.x, c))))) / tanh(tan(tan(f(r.x, c)))))))))) + cast<float>(cast<uint8_t>(f(r.x, c) / cast<uint8_t>(log(f(r.x, c))))))))))));

f.gpu_tile(x, c, xi, ci, 4, 4);
f.update(0).gpu_tile(r.x, c, rxi, ci, 4, 4);

// Metal is surprisingly resilient. Run this in a loop just to make sure we trigger the error.
for (int i = 0; (i < 10) && !errored; i++) {
auto out = f.realize({1000, 100}, t);
int result = out.device_sync();
if (result != halide_error_code_success) {
printf("Device sync failed as expected: %d\n", result);
errored = true;
}
}

if (!errored) {
printf("There was supposed to be an error\n");
return 1;
}

printf("Success!\n");
return 0;
}
5 changes: 5 additions & 0 deletions test/generator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,11 @@ _add_halide_libraries(metadata_tester_ucon
_add_halide_aot_tests(metadata_tester
HALIDE_LIBRARIES metadata_tester metadata_tester_ucon)

# metal_completion_handler_override_aottest.cpp
# metal_completion_handler_override_generator.cpp
_add_halide_libraries(metal_completion_handler_override FEATURES user_context)
_add_halide_aot_tests(metal_completion_handler_override)

# msan_aottest.cpp
# msan_generator.cpp
if ("${Halide_TARGET}" MATCHES "webgpu")
Expand Down
Loading
Loading