Skip to content

Commit

Permalink
[RUNTIME][METAL] Provide richer runtime when error happens (#16713)
Browse files Browse the repository at this point in the history
This PR enhances the Metal runtime to include more detailed error messages
when an error happens.
  • Loading branch information
tqchen authored Mar 14, 2024
1 parent 071fb8a commit 0978ab6
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 11 deletions.
27 changes: 19 additions & 8 deletions src/runtime/metal/metal_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

#include "../workspace_pool.h"
Expand Down Expand Up @@ -106,25 +107,35 @@ class AutoReleasePoolWrapper {
*/
class Stream {
 public:
  /*!
   * \brief Create a stream backed by a new command queue of the device.
   * \param device The Metal device used to create the command queue.
   */
  explicit Stream(id<MTLDevice> device) { queue_ = [device newCommandQueue]; }

  // The queue was obtained through a `new...` call in the constructor,
  // so this class owns it and releases it here.
  ~Stream() { [queue_ release]; }

  /*!
   * \brief Create a command buffer on this stream's queue.
   * \param attach_error_callback Whether to register a completion handler that
   *        records the command buffer's error on this stream. Pass false when
   *        the caller registers its own handler that calls SetError, so that
   *        only one handler reports the failure.
   * \return The newly created command buffer.
   * \note The handler captures `this`, so the Stream must outlive every
   *       command buffer it hands out with the callback attached.
   */
  id<MTLCommandBuffer> GetCommandBuffer(bool attach_error_callback = true) {
    id<MTLCommandBuffer> cb = [queue_ commandBuffer];
    // Honor the flag: callers that opt out attach a more specific handler.
    if (attach_error_callback) {
      [cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
        if (buffer.status == MTLCommandBufferStatusError) {
          ICHECK(buffer.error != nil);
          // UTF8String may be NULL; std::string must not be built from NULL.
          const char* message = buffer.error.localizedDescription.UTF8String;
          this->SetError(message != nullptr ? message : "unknown error");
        }
      }];
    }
    return cb;
  }

  /*!
   * \brief Mark the stream as failed and store the error description.
   * \param error_description Human readable description of the failure.
   */
  void SetError(std::string error_description) {
    error_happened_ = true;
    error_description_ = std::move(error_description);
  }

  /*! \return Whether SetError has been called on this stream. */
  bool HasErrorHappened() const { return error_happened_; }

  /*! \return The description passed to the most recent SetError call. */
  const std::string& ErrorDescription() const { return error_description_; }

 private:
  // Queue
  id<MTLCommandQueue> queue_;
  // Check if error happened in one previous run
  bool error_happened_{false};
  // error description
  std::string error_description_;
};

/*!
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ int GetWarpSize(id<MTLDevice> dev) {
if (dev_from.device_type == kDLCPU) dev = dev_to;
Stream* s = this->CastStreamOrGetDefault(stream, dev.device_id);
if (s->HasErrorHappened()) {
LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to current stream";
LOG(FATAL) << "GPUError: " << s->ErrorDescription();
}
id<MTLCommandBuffer> cb = s->GetCommandBuffer();
int from_dev_type = static_cast<int>(dev_from.device_type);
Expand Down Expand Up @@ -301,7 +301,7 @@ int GetWarpSize(id<MTLDevice> dev) {
[cb commit];
[cb waitUntilCompleted];
if (s->HasErrorHappened()) {
LOG(FATAL) << "Error! Some problems on GPU happaned!";
LOG(FATAL) << "GPUError: " << s->ErrorDescription();
}
};
}
Expand Down
16 changes: 15 additions & 1 deletion src/runtime/metal/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -194,15 +194,19 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
// obtain the stream
auto stream =
metal::MetalWorkspace::Global()->CastStreamOrGetDefault(t->stream[device_id], device_id);

// skip launching so the error can be printed during sync
if (stream->HasErrorHappened()) return;

if (scache_[device_id] == nil) {
scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
}
ThreadWorkLoad wl = launch_param_config_.Extract(args);
int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2);
auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup;
CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup);
id<MTLCommandBuffer> cb = stream->GetCommandBuffer();
// attach error message directly in this function
id<MTLCommandBuffer> cb = stream->GetCommandBuffer(/* attach_error_callback= */ false);
id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
[encoder setComputePipelineState:scache_[device_id]];
for (size_t i = 0; i < num_buffer_args_; ++i) {
Expand All @@ -219,6 +223,16 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
MTLSize dimBlock = MTLSizeMake(wl.block_dim(0), wl.block_dim(1), wl.block_dim(2));
[encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock];
[encoder endEncoding];
// attach error message with function name
[cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
if (buffer.status == MTLCommandBufferStatusError) {
ICHECK(buffer.error != nil);
std::ostringstream os;
os << "GPUError happens after running " << func_name_ << ": "
<< buffer.error.localizedDescription.UTF8String;
stream->SetError(os.str());
}
}];
[cb commit];
};
}
Expand Down

0 comments on commit 0978ab6

Please sign in to comment.