Skip to content

Commit

Permalink
[Unity][WEBGPU] Enable wasm exception propagation (#16330)
Browse files Browse the repository at this point in the history
This PR enables wasm exception propagation among
c++ runtime generated wasm and javascript.

Right now the error.message is passed back
this would allow us to do some handling in webgpu
related exceptions raised through FFI boundaries.

Note that this would require the latest emscripten
and on the nodejs, --experimental-wasm-eh support.
  • Loading branch information
tqchen authored Jan 4, 2024
1 parent d509661 commit 49fc613
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 7 deletions.
1 change: 1 addition & 0 deletions python/tvm/contrib/emcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def create_tvmjs_wasm(output, objects, options=None, cc="emcc"):
cmd += ["-O3"]
cmd += ["-std=c++17"]
cmd += ["--no-entry"]
cmd += ["-fwasm-exception"]
cmd += ["-s", "WASM_BIGINT=1"]
cmd += ["-s", "ERROR_ON_UNDEFINED_SYMBOLS=0"]
cmd += ["-s", "STANDALONE_WASM=1"]
Expand Down
2 changes: 1 addition & 1 deletion web/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ all: dist/wasm/tvmjs_runtime.wasm dist/wasm/tvmjs_runtime.wasi.js src/tvmjs_runt

EMCC = emcc

EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++17 -Wno-ignored-attributes
EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++17 -Wno-ignored-attributes -fwasm-exceptions

EMCC_LDFLAGS = --no-entry -s WASM_BIGINT=1 -s ALLOW_MEMORY_GROWTH=1 -s STANDALONE_WASM=1\
-s ERROR_ON_UNDEFINED_SYMBOLS=0 --pre-js emcc/preload.js
Expand Down
11 changes: 9 additions & 2 deletions web/emcc/tvmjs_support.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,15 @@ class AsyncLocalSession : public LocalSession {
int code = args[0];
TVMRetValue rv;
rv = args[1];
this->EncodeReturn(std::move(rv),
[&](TVMArgs encoded_args) { callback(RPCCode::kReturn, encoded_args); });
if (code == static_cast<int>(RPCCode::kReturn)) {
this->EncodeReturn(std::move(rv), [&](TVMArgs encoded_args) {
callback(RPCCode::kReturn, encoded_args);
});
} else {
// for exception, we can pass through as since this is just normal encoding.
ICHECK_EQ(code, static_cast<int>(RPCCode::kException));
callback(RPCCode::kException, args);
}
});

TVMRetValue temp;
Expand Down
2 changes: 1 addition & 1 deletion web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"build": "rollup -c",
"lint": "eslint -c .eslintrc.json .",
"typedoc": "typedoc src/index.ts --plugin typedoc-plugin-missing-exports",
"test": "jest",
"test": "node --experimental-wasm-eh node_modules/.bin/jest",
"bundle": "npm run build && cp lib/index.js dist/index.js && cp lib/index.js dist/tvmjs.bundle.js",
"example": "npm run bundle && node apps/node/example.js",
"example:wasi": "npm run bundle && node --experimental-wasi-unstable-preview1 --experimental-wasm-bigint apps/node/wasi_example.js",
Expand Down
5 changes: 5 additions & 0 deletions web/src/ctypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ export type PtrOffset = number;
*/
export type FTVMGetLastError = () => Pointer;

/**
* void TVMAPISetLastError(const char* msg);
*/
export type FTVMAPISetLastError = (msg: Pointer) => void;

/**
* int TVMModGetFunction(TVMModuleHandle mod,
* const char* func_name,
Expand Down
31 changes: 28 additions & 3 deletions web/src/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class FFILibrary implements Disposable {
if (code != 0) {
const msgPtr = (this.exports
.TVMGetLastError as ctypes.FTVMGetLastError)();
console.log("Here");
throw new Error("TVMError: " + this.memory.loadCString(msgPtr));
}
}
Expand Down Expand Up @@ -1902,10 +1903,15 @@ export class Instance implements Disposable {
// need to keep it alive until callback is fulfilled.
const callback = this.detachFromCurrentScope(args[args.length - 1] as PackedFunc);
const promise: Promise<any> = func(...fargs);
promise.then((rv: any) => {
const onFulfilled = (rv: any) => {
callback(this.scalar(AsyncCallbackCode.kReturn, "int32"), rv);
callback.dispose();
});
};
const onRejected = (reason: any) => {
callback(this.scalar(AsyncCallbackCode.kException, "int32"), reason.toString());
callback.dispose();
};
promise.then(onFulfilled, onRejected);
};
this.registerFunc("__async." + name, asyncVariant, override);
}
Expand Down Expand Up @@ -2216,7 +2222,26 @@ export class Instance implements Disposable {
jsArgs.push(this.retValueToJS(valuePtr, tcode, true));
}

const rv = func(...jsArgs);
let rv: any;
try {
rv = func(...jsArgs);
} catch (error) {
// error handling
// store error via SetLastError
this.ctx.endScope();
const errMsg = "JSCallbackError: " + error.message;
const stack = lib.getOrAllocCallStack();
const errMsgOffset = stack.allocRawBytes(errMsg.length + 1);
stack.storeRawBytes(errMsgOffset, StringToUint8Array(errMsg));
stack.commitToWasmMemory();
(this.lib.exports.TVMAPISetLastError as ctypes.FTVMAPISetLastError)(
stack.ptrFromOffset(errMsgOffset)
);
this.lib.recycleCallStack(stack);
return -1;
}

// normal return path
// recycle all js object value in function unless we want to retain them.
this.ctx.endScope();

Expand Down
15 changes: 15 additions & 0 deletions web/tests/node/test_packed_func.js
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,21 @@ test("RegisterGlobal", () => {
tvm.endScope();
});

test("ExceptionPassing", () => {
tvm.beginScope();
tvm.registerFunc("throw_error", function (msg) {
throw Error(msg);
});
let f = tvm.getGlobalFunc("throw_error");
try {
f("error-xyz");
throw Error("error not caught");
} catch (error) {
assert(error.message.indexOf("error-xyz") != -1);
}
tvm.endScope();
});

test("NDArrayCbArg", () => {
tvm.beginScope();
let use_count = tvm.getGlobalFunc("testing.object_use_count");
Expand Down

0 comments on commit 49fc613

Please sign in to comment.