Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tools for RPC testing #1

Merged
merged 1 commit into from
Mar 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions build.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ def debug_dump(mod, name, args):
"""Debug dump mode"""
if not args.debug_dump:
return
dump_path = os.path.join(
args.artifact_path, "debug", name)
dump_path = os.path.join(args.artifact_path, "debug", name)
with open(dump_path, "w") as outfile:
outfile.write(mod.script(show_meta=True))
print(f"Dump mod to {dump_path}")


def build(mod: tvm.IRModule, args: Dict) -> None:
from tvm import meta_schedule as ms

Expand Down
1 change: 1 addition & 0 deletions deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def __call__(self, prompt: str, negative_prompt: str = ""):
for i in tqdm(range(num_inference_steps)):
t = self.scheduler.timesteps[i]
self.debug_dump(f"unet_input_{i}", latents)
self.debug_dump(f"timestep_{i}", t)
noise_pred = self.unet_latents_to_noise_pred(latents, t, text_embeddings)
self.debug_dump(f"unet_output_{i}", noise_pred)
latents = self.scheduler.step(self.vm, noise_pred, latents, i)
Expand Down
2 changes: 1 addition & 1 deletion scripts/local_deploy_site.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ scripts/build_site.sh web/local-config.json

echo "symlink parameter location to site.."

ln -s `pwd`/dist/params site/_site/dist/web-sd-shards-v1-5
ln -s `pwd`/dist/params site/_site/web-sd-shards-v1-5
cd site && jekyll serve --skip-initial-build --host localhost --baseurl /web-stable-diffusion --port 8888
22 changes: 22 additions & 0 deletions scripts/rpc_debug_deploy.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/bin/bash
# Deploy the web stable-diffusion debug assets into the TVM web RPC plugin
# directory so they can be served to the browser-side RPC runner.
set -euxo pipefail

TVM_HOME_SET="${TVM_HOME:-}"

if [[ -z "${TVM_HOME_SET}" ]]; then
    echo "Require TVM_HOME to be set"
    exit 255
fi

echo "Copy files..."
# Quote every expansion so a TVM_HOME containing spaces does not word-split.
mkdir -p "${TVM_HOME}/web/dist/www/dist/"
cp web/stable_diffusion.html "${TVM_HOME}/web/dist/www/rpc_plugin.html"
cp web/stable_diffusion.js "${TVM_HOME}/web/dist/www/dist/"
cp web/local-config.json "${TVM_HOME}/web/dist/www/stable-diffusion-config.json"

cp dist/scheduler_consts.json "${TVM_HOME}/web/dist/www/dist/"
cp dist/stable_diffusion_webgpu.wasm "${TVM_HOME}/web/dist/www/dist/"
cp -rf dist/tokenizers-wasm "${TVM_HOME}/web/dist/www/dist/"

# Refresh the parameter-shard symlink in the ndarray cache.
rm -rf "${TVM_HOME}/web/.ndarray_cache/web-sd-shards-v1-5"
ln -s "$(pwd)/dist/params" "${TVM_HOME}/web/.ndarray_cache/web-sd-shards-v1-5"
4 changes: 3 additions & 1 deletion site/_includes/head.html
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
<meta http-equiv="origin-trial" content="Agx76XA0ITxMPF0Z8rbbcMllwuxsyp9qdtQaXlLqu1JUrdHB6FPonuyIKJ3CsBREUkeioJck4nn3KO0c0kkwqAMAAABJeyJvcmlnaW4iOiJodHRwOi8vbG9jYWxob3N0Ojg4ODgiLCJmZWF0dXJlIjoiV2ViR1BVIiwiZXhwaXJ5IjoxNjkxNzExOTk5fQ==">
<meta http-equiv="origin-trial" content="AnmwqQ1dtYDQTYkZ5iMtHdINCaxjE94uWQBKp2yOz1wPTcjSRtOHUGQG+r2BxsEuM0qhxTVnuTjyh31HgTeA8gsAAABZeyJvcmlnaW4iOiJodHRwczovL21sYy5haTo0NDMiLCJmZWF0dXJlIjoiV2ViR1BVIiwiZXhwaXJ5IjoxNjkxNzExOTk5LCJpc1N1YmRvbWFpbiI6dHJ1ZX0=">
<meta http-equiv="origin-trial" content="AnmwqQ1dtYDQTYkZ5iMtHdINCaxjE94uWQBKp2yOz1wPTcjSRtOHUGQG+r2BxsEuM0qhxTVnuTjyh31HgTeA8gsAAABZeyJvcmlnaW4iOiJodHRwczovL21sYy5haTo0NDMiLCJmZWF0dXJlIjoiV2ViR1BVIiwiZXhwaXJ5IjoxNjkxNzExOTk5LCJpc1N1YmRvbWFpbiI6dHJ1ZX0=">
<script src="dist/tvmjs_runtime.wasi.js"></script>
<script src="dist/tvmjs.bundle.js"></script>
62 changes: 62 additions & 0 deletions tests/debug/rpc_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os
import json
import numpy as np
import argparse
from web_stable_diffusion.rpc_testing import WebGPUDebugSession


def load_checkpt(args, name):
    """Load a debug checkpoint array dumped during deployment."""
    npy_file = os.path.join(args.artifact_path, "debug", name + ".npy")
    return np.load(npy_file)


def load_metadata(args):
    """Return the ``metadata`` section of the ndarray cache manifest.

    Reads ``<artifact_path>/params/ndarray-cache.json`` and extracts its
    top-level ``metadata`` dictionary.
    """
    json_path = os.path.join(args.artifact_path, "params", "ndarray-cache.json")
    # Use a context manager so the file handle is closed deterministically
    # (the original `json.load(open(...))` relied on the GC to close it).
    with open(json_path, "r") as json_file:
        return json.load(json_file)["metadata"]


def main_vae(args):
    """Run the VAE stage remotely and compare against the dumped reference."""
    wasm_path = os.path.join(args.artifact_path, "stable_diffusion_webgpu.wasm")
    session = WebGPUDebugSession(wasm_path)
    latents = load_checkpt(args, "vae_input")
    expected = load_checkpt(args, "vae_output")
    nparams = load_metadata(args)["vaeParamSize"]
    run_vae = session.get_wrapper("vae", nparams, time_eval=args.time_eval)
    got = run_vae(latents)
    # Outputs are 0-255 pixel values; atol=1 tolerates minor relative error.
    np.testing.assert_allclose(got, expected, atol=1)


def main_unet(args):
    """Run one UNet denoising step remotely and compare to the reference.

    Uses the checkpoints dumped at step ``args.counter`` during deployment
    (unet input/output, timestep) plus the shared text embeddings.
    """
    sess = WebGPUDebugSession(
        os.path.join(args.artifact_path, "stable_diffusion_webgpu.wasm")
    )
    unet_input = load_checkpt(args, f"unet_input_{args.counter}")
    # Fixed: was f"text_embeddings" — an f-string with no placeholders.
    text_embeddings = load_checkpt(args, "text_embeddings")
    timestep = load_checkpt(args, f"timestep_{args.counter}")
    unet_output = load_checkpt(args, f"unet_output_{args.counter}")
    nparams = load_metadata(args)["unetParamSize"]
    unet = sess.get_wrapper("unet", nparams, time_eval=args.time_eval)
    result = unet(unet_input, timestep, text_embeddings)
    np.testing.assert_allclose(result, unet_output, atol=4e-5)


def _parse_args():
args = argparse.ArgumentParser()
args.add_argument("--artifact-path", type=str, default="dist")
args.add_argument("--stage", type=str, choices=["unet", "vae"], required=True)
args.add_argument("--counter", type=int, default=0)
args.add_argument("--time-eval", default=False, action="store_true")
parsed = args.parse_args()
return parsed


if __name__ == "__main__":
    args = _parse_args()
    # Dispatch the selected pipeline stage; argparse already restricts
    # --stage to these choices, so the error branch is defensive.
    stage_runners = {"vae": main_vae, "unet": main_unet}
    runner = stage_runners.get(args.stage)
    if runner is None:
        raise ValueError(f"Unknown stage {args.stage}")
    runner(args)
    print(f"All pass running stage {args.stage}, counter={args.counter}")
2 changes: 1 addition & 1 deletion web/local-config.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"schedulerConstUrl": "dist/scheduler_consts.json",
"wasmUrl": "dist/stable_diffusion_webgpu.wasm",
"cacheUrl": "dist/web-sd-shards-v1-5/",
"cacheUrl": "web-sd-shards-v1-5/",
"tokenizer": "openai/clip-vit-large-patch14"
}
3 changes: 0 additions & 3 deletions web/stable_diffusion.html
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
<script src="dist/tvmjs_runtime.wasi.js"></script>
<script src="dist/tvmjs.bundle.js"></script>

<script>
var tvmjsGlobalEnv = tvmjsGlobalEnv || {};
</script>
Expand Down
9 changes: 4 additions & 5 deletions web/stable_diffusion.js
Original file line number Diff line number Diff line change
Expand Up @@ -441,19 +441,18 @@ class StableDiffusionInstance {
}
this.tvm = tvmInstance;

await this.#asyncInitConfig();
await this.#asyncInitPipeline(this.config.schedulerConstUrl, this.config.tokenizer);

this.tvm.beginScope();
this.tvm.registerAsyncServerFunc("generate", async (prompt, vaeCycle) => {
document.getElementById("inputPrompt").value = prompt;
const negPrompt = "";
document.getElementById("negativePrompt").value = "";
await this.pipeline.generate(prompt, negPrompt, this.#getProgressCallback(), vaeCycle);
});

this.tvm.registerAsyncServerFunc("clearCanvas", async () => {
this.pipeline.clearCanvas();
this.tvm.clearCanvas();
});
this.tvm.registerAsyncServerFunc("showImage", async (data) => {
this.tvm.showImage(data);
});
this.tvm.endScope();
}
Expand Down
143 changes: 143 additions & 0 deletions web_stable_diffusion/rpc_testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Testing utilities through rpc."""
import os

import numpy as np
import tvm
from tvm import relax, rpc

try:
    import torch
except ImportError:
    # Mark torch as unavailable so callers can test `torch is not None`
    # rather than hitting a NameError on `torch.Tensor`.
    torch = None


class RPCBaseDebugSession:
    """A helper class to create debug sessions through rpc.

    Parameters
    ----------
    remote: RPCSession
        The input rpc session.

    rt_mod: runtime.Module
        Runtime module that loads the VM.

    device: runtime.Device
        Device to run the vm on.
    """

    def __init__(self, remote, rt_mod, device):
        self.remote = remote
        self.vm = relax.VirtualMachine(rt_mod, device)
        self.device = device

    def get_wrapper(self, func_name, nparam_cached=0, time_eval=False):
        """Get remote debug wrapper.

        Parameters
        ----------
        func_name: str
            The function name.

        nparam_cached: int
            Number of extra cached parameters to append via the remote
            parameter-module cache.

        time_eval: bool or dict
            Whether to perform time evaluation on the first call; a dict is
            forwarded as keyword arguments to the time evaluator.

        Returns
        -------
        wrapper: Callable
            The callable that can be used to run related items. Inputs may
            be tvm NDArray, numpy ndarray, or torch Tensor; the return value
            is converted back to the kind of the last array argument.
        """
        pfunc = None
        if nparam_cached != 0:
            pfunc_from_cache = self.remote.get_function("tvmjs.param_module_from_cache")
            pfunc = pfunc_from_cache(func_name, nparam_cached)

        if isinstance(time_eval, dict):
            time_eval_kwargs = time_eval
        else:
            time_eval_kwargs = {} if time_eval else None

        # Non-empty once timing has run, so only the first call is timed.
        time_eval_result = []

        def wrapped_f(*args):
            new_args = []
            ret_kind = None
            ret_device = None

            for arg in args:
                if isinstance(arg, tvm.nd.NDArray):
                    ret_kind = tvm.nd.NDArray
                    ret_device = arg.device
                    if arg.device != self.device:
                        # Route through CPU; a direct device-to-device copy
                        # may not be supported.
                        arg = arg.copyto(tvm.cpu()).copyto(self.device)
                elif isinstance(arg, np.ndarray):
                    ret_kind = np.ndarray
                    arg = tvm.nd.array(arg, self.device)
                elif torch is not None and isinstance(arg, torch.Tensor):
                    ret_kind = torch.Tensor
                    ret_device = arg.device
                    arg = tvm.nd.array(arg.numpy(), self.device)
                new_args.append(arg)
            if pfunc:
                new_args.append(pfunc)

            if pfunc:
                self.vm.module["set_input_with_param_module"](func_name, *new_args)
            else:
                self.vm.module["set_input"](func_name, *new_args)

            self.vm.invoke_stateful(func_name)
            if time_eval_kwargs is not None and len(time_eval_result) == 0:
                res = self.vm.time_evaluator("invoke_stateful", self.device)(
                    func_name, **time_eval_kwargs
                )
                time_eval_result.append(res)
                # BUG FIX: was `self.devcice`, which raised AttributeError
                # whenever time_eval was enabled.
                print(f"Remote[{func_name}] on {self.device}, {res}")

            outputs = self.vm.get_outputs(func_name)

            def _convert_return(data):
                # Recursively convert VM outputs back to the caller's kind.
                if isinstance(data, (tvm.ir.Array, list, tuple)):
                    return [_convert_return(x) for x in data]
                if not isinstance(data, tvm.nd.NDArray):
                    return data
                if ret_kind == tvm.nd.NDArray:
                    if ret_device == data.device:
                        return data
                    return data.copyto(tvm.cpu()).copyto(ret_device)
                if torch is not None and ret_kind == torch.Tensor:
                    return torch.from_numpy(data.numpy()).to(ret_device)
                if ret_kind == np.ndarray:
                    return data.numpy()
                raise ValueError(f"Unknown ret kind {ret_kind}")

            return _convert_return(outputs)

        return wrapped_f


class WebGPUDebugSession(RPCBaseDebugSession):
    """Remote debug session to handle webgpu.

    Connects to an RPC proxy (default 127.0.0.1:9090, overridable via the
    TVM_RPC_PROXY_HOST / TVM_RPC_PROXY_PORT environment variables) and
    constructs a wasm session on the remote side.

    Parameters
    ----------
    wasm_path: str
        The path to the wasm.
    """

    def __init__(self, wasm_path):
        proxy_host = os.environ.get("TVM_RPC_PROXY_HOST", "127.0.0.1")
        proxy_port = int(os.environ.get("TVM_RPC_PROXY_PORT", "9090"))
        # Read via a context manager so the file handle is closed promptly
        # (the original `open(...).read()` relied on the GC to close it).
        with open(wasm_path, "rb") as wasm_file:
            wasm_binary = wasm_file.read()
        remote = rpc.connect(
            proxy_host,
            proxy_port,
            key="wasm",
            session_constructor_args=["rpc.WasmSession", wasm_binary],
        )
        super().__init__(remote, remote.system_lib(), remote.webgpu())