Skip to content

Commit

Permalink
[WEB] Initial support for asyncify (#16694)
Browse files Browse the repository at this point in the history
This PR enables asyncify support for web runtime.

Asyncify is a feature to allow C++ to call async function in javascript.
The emcc compiler will unwind and store the stack, returning control
to JS runtime. The JS runtime needs to be able to await the promise
and then call rewind to get to the original suspended point.

This feature can be potentially useful when we would like to
call WebGPU sync in C++ runtime. As on web platform everything
have to be non-blocking.

Because asyncify can increase the wasm size by 2x, we don't enable
it by default in emcc.py and still would need to pass in options.

We will confirm potential benefit tradeoffs before turning it on by default.
Another catch is that as of now asyncify is not compatible with wasm
exception, so we temporary turn wasm-exception it off for now.
This is an item that is being worked on by emscripten so we might
be able to turn it back on later.

The testcases are added.

reference: https://emscripten.org/docs/porting/asyncify.html
  • Loading branch information
tqchen authored Mar 11, 2024
1 parent 639a6e4 commit 7ac03ca
Show file tree
Hide file tree
Showing 12 changed files with 395 additions and 26 deletions.
9 changes: 8 additions & 1 deletion python/tvm/contrib/emcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,14 @@ def create_tvmjs_wasm(output, objects, options=None, cc="emcc"):
cmd += ["-O3"]
cmd += ["-std=c++17"]
cmd += ["--no-entry"]
cmd += ["-fwasm-exceptions"]
# NOTE: asynctify conflicts with wasm-exception
# so we temp disable exception handling for now
#
# We also expect user to explicitly pass in
# -s ASYNCIFY=1 as it can increase wasm size by 2xq
#
# cmd += ["-s", "ASYNCIFY=1"]
# cmd += ["-fwasm-exceptions"]
cmd += ["-s", "WASM_BIGINT=1"]
cmd += ["-s", "ERROR_ON_UNDEFINED_SYMBOLS=0"]
cmd += ["-s", "STANDALONE_WASM=1"]
Expand Down
1 change: 0 additions & 1 deletion src/runtime/c_runtime_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,6 @@ int TVMByteArrayFree(TVMByteArray* arr) {
int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args,
TVMValue* ret_val, int* ret_type_code) {
API_BEGIN();

TVMRetValue rv;
(static_cast<const PackedFuncObj*>(func))
->CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv);
Expand Down
5 changes: 3 additions & 2 deletions web/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ all: dist/wasm/tvmjs_runtime.wasm dist/wasm/tvmjs_runtime.wasi.js src/tvmjs_runt

EMCC = emcc

EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++17 -Wno-ignored-attributes -fwasm-exceptions
EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++17 -Wno-ignored-attributes

EMCC_LDFLAGS = --no-entry -s WASM_BIGINT=1 -s ALLOW_MEMORY_GROWTH=1 -s STANDALONE_WASM=1\
-s ERROR_ON_UNDEFINED_SYMBOLS=0 --pre-js emcc/preload.js
-s ERROR_ON_UNDEFINED_SYMBOLS=0 --pre-js emcc/preload.js\
-s ASYNCIFY=1

dist/wasm/%.bc: emcc/%.cc
@mkdir -p $(@D)
Expand Down
2 changes: 1 addition & 1 deletion web/apps/node/example.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
*/
const path = require("path");
const fs = require("fs");
const tvmjs = require("../../lib");
const tvmjs = require("../../dist/tvmjs.bundle");

const wasmPath = tvmjs.wasmPath();
const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm"));
Expand Down
1 change: 1 addition & 0 deletions web/emcc/decorate_as_wasi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

template_head = """
function EmccWASI() {
var asyncifyStubs = {};
"""

template_tail = """
Expand Down
5 changes: 5 additions & 0 deletions web/emcc/wasm_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ TVM_REGISTER_GLOBAL("testing.echo").set_body([](TVMArgs args, TVMRetValue* ret)
*ret = args[0];
});

TVM_REGISTER_GLOBAL("testing.call").set_body([](TVMArgs args, TVMRetValue* ret) {
(args[0].operator PackedFunc())
.CallPacked(TVMArgs(args.values + 1, args.type_codes + 1, args.num_args - 1), ret);
});

TVM_REGISTER_GLOBAL("testing.ret_string").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator String();
});
Expand Down
6 changes: 5 additions & 1 deletion web/emcc/webgpu_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,11 @@ class WebGPUDeviceAPI : public DeviceAPI {
LOG(FATAL) << "Not implemented";
}

void StreamSync(Device dev, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; }
void StreamSync(Device dev, TVMStreamHandle stream) final {
static const PackedFunc* func = runtime::Registry::Get("__asyncify.WebGPUWaitForTasks");
ICHECK(func != nullptr) << "Stream sync inside c++ only supported in asyncify mode";
(*func)();
}

void SetStream(Device dev, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; }

Expand Down
46 changes: 32 additions & 14 deletions web/src/artifact_cache.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,37 @@
/*
Common Interface for the artifact cache
*/
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Common Interface for the artifact cache
*/
export interface ArtifactCacheTemplate {
/**
* fetch key url from cache
*/
fetchWithCache(url: string);
/**
* fetch key url from cache
*/
fetchWithCache(url: string);

/**
* check if cache has all keys in Cache
*/
hasAllKeys(keys: string[]);
/**
* check if cache has all keys in Cache
*/
hasAllKeys(keys: string[]);

/**
* Delete url in cache if url exists
*/
deleteInCache(url: string);
/**
* Delete url in cache if url exists
*/
deleteInCache(url: string);
}
227 changes: 227 additions & 0 deletions web/src/asyncify.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
// Helper tools to enable asynctify handling
// Thie following code is used to support wrapping of
// functins that can have async await calls in the backend runtime
// reference
// - https://kripken.github.io/blog/wasm/2019/07/16/asyncify.html
// - https://github.com/GoogleChromeLabs/asyncify
import { assert, isPromise } from "./support";

/**
* enums to check the current state of asynctify
*/
const enum AsyncifyStateKind {
None = 0,
Unwinding = 1,
Rewinding = 2
}

/** The start location of asynctify stack data */
const ASYNCIFY_DATA_ADDR = 16;
/** The data start of stack rewind/unwind */
const ASYNCIFY_DATA_START = ASYNCIFY_DATA_ADDR + 8;
/** The data end of stack rewind/unwind */
const ASYNCIFY_DATA_END = 1024;

/** Hold asynctify handler instance that runtime can use */
export class AsyncifyHandler {
/** exports from wasm */
private exports: Record<string, Function>;
/** current state kind */
private state: AsyncifyStateKind = AsyncifyStateKind.None;
/** The stored value before unwind */
private storedPromiseBeforeUnwind : Promise<any> = null;
// NOTE: asynctify do not work with exceptions
// this implementation here is mainly for possible future compact
/** The stored value that is resolved */
private storedValueBeforeRewind: any = null;
/** The stored exception */
private storedExceptionBeforeRewind: any = null;

constructor(exports: Record<string, Function>, memory: WebAssembly.Memory) {
this.exports = exports;
this.initMemory(memory);
}

// NOTE: wrapImport and wrapExport are closely related to each other
// We mark the logical jump pt in comments to increase the readability
/**
* Whether the wasm enables asynctify
* @returns Whether the wasm enables asynctify
*/
enabled(): boolean {
return this.exports.asyncify_stop_rewind !== undefined;
}

/**
* Get the current asynctify state
*
* @returns The current asynctify state
*/
getState(): AsyncifyStateKind {
return this.state;
}

/**
* Wrap a function that can be used as import of the wasm asynctify layer
*
* @param func The input import function
* @returns The wrapped function that can be registered to the system
*/
wrapImport(func: (...args: Array<any>) => any): (...args: Array<any>) => any {
return (...args: any) => {
// this is being called second time
// where we are rewinding the stack
if (this.getState() == AsyncifyStateKind.Rewinding) {
// JUMP-PT-REWIND: rewind will jump to this pt
// while rewinding the stack
this.stopRewind();
// the value has been resolved
if (this.storedValueBeforeRewind !== null) {
assert(this.storedExceptionBeforeRewind === null);
const result = this.storedValueBeforeRewind;
this.storedValueBeforeRewind = null;
return result;
} else {
assert(this.storedValueBeforeRewind === null);
const error = this.storedExceptionBeforeRewind;
this.storedExceptionBeforeRewind = null;
throw error;
}
}
// this function is being called for the first time
assert(this.getState() == AsyncifyStateKind.None);

// call the function
const value = func(...args);
// if the value is promise
// we need to unwind the stack
// so the caller will be able to evaluate the promise
if (isPromise(value)) {
// The next code step is JUMP-PT-UNWIND in wrapExport
// The value will be passed to that pt through storedPromiseBeforeUnwind
// getState() == Unwinding and we will enter the while loop in wrapExport
this.startUnwind();
assert(this.storedPromiseBeforeUnwind == null);
this.storedPromiseBeforeUnwind = value;
return undefined;
} else {
// The next code step is JUMP-PT-UNWIND in wrapExport
// normal value, we don't have to do anything
// getState() == None and we will exit while loop there
return value;
}
};
}

/**
* Warp an exported asynctify function so it can return promise
*
* @param func The input function
* @returns The wrapped async function
*/
wrapExport(func: (...args: Array<any>) => any): (...args: Array<any>) => Promise<any> {
return async (...args: Array<any>) => {
assert(this.getState() == AsyncifyStateKind.None);

// call the original function
let result = func(...args);

// JUMP-PT-UNWIND
// after calling the function
// the caller may hit a unwinding point depending on
// the if (isPromise(value)) condition in wrapImport
while (this.getState() == AsyncifyStateKind.Unwinding) {
this.stopUnwind();
// try to resolve the promise that the internal requested
// we then store it into the temp value in storedValueBeforeRewind
// which then get passed onto the function(see wrapImport)
// that can return the value
const storedPromiseBeforeUnwind = this.storedPromiseBeforeUnwind;
this.storedPromiseBeforeUnwind = null;
assert(this.storedExceptionBeforeRewind === null);
assert(this.storedValueBeforeRewind == null);

try {
this.storedValueBeforeRewind = await storedPromiseBeforeUnwind;
} catch (error) {
// the store exception
this.storedExceptionBeforeRewind = error;
}
assert(!isPromise(this.storedValueBeforeRewind));
// because we called asynctify_stop_unwind,the state is now none
assert(this.getState() == AsyncifyStateKind.None);

// re-enter the function, jump to JUMP-PT-REWIND in wrapImport
// the value will be passed to that point via storedValueBeforeRewind
//
// NOTE: we guarantee that if exception is throw the asynctify state
// will already be at None, this is because we will goto JUMP-PT-REWIND
// which will call aynctify_stop_rewind
this.startRewind();
result = func(...args);
}
return result;
};
}

private startRewind() : void {
if (this.exports.asyncify_start_rewind === undefined) {
throw Error("Asynctify is not enabled, please compile with -s ASYNCIFY=1 in emcc");
}
this.exports.asyncify_start_rewind(ASYNCIFY_DATA_ADDR);
this.state = AsyncifyStateKind.Rewinding;
}

private stopRewind() : void {
if (this.exports.asyncify_stop_rewind === undefined) {
throw Error("Asynctify is not enabled, please compile with -s ASYNCIFY=1 in emcc");
}
this.exports.asyncify_stop_rewind();
this.state = AsyncifyStateKind.None;
}

private startUnwind() : void {
if (this.exports.asyncify_start_unwind === undefined) {
throw Error("Asynctify is not enabled, please compile with -s ASYNCIFY=1 in emcc");
}
this.exports.asyncify_start_unwind(ASYNCIFY_DATA_ADDR);
this.state = AsyncifyStateKind.Unwinding;
}

private stopUnwind() : void {
if (this.exports.asyncify_stop_unwind === undefined) {
throw Error("Asynctify is not enabled, please compile with -s ASYNCIFY=1 in emcc");
}
this.exports.asyncify_stop_unwind();
this.state = AsyncifyStateKind.None;
}
/**
* Initialize the wasm memory to setup necessary meta-data
* for asynctify handling
* @param memory The memory ti
*/
private initMemory(memory: WebAssembly.Memory): void {
// Set the meta-data at address ASYNCTIFY_DATA_ADDR
new Int32Array(memory.buffer, ASYNCIFY_DATA_ADDR, 2).set(
[ASYNCIFY_DATA_START, ASYNCIFY_DATA_END]
);
}
}
Loading

0 comments on commit 7ac03ca

Please sign in to comment.