Skip to content

Commit

Permalink
Fix after merge
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Oct 29, 2023
1 parent 3184a80 commit 393aaa3
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 16 deletions.
19 changes: 4 additions & 15 deletions src/target/source/codegen_webgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,6 @@ std::string CodeGenWebGPU::Finish() {

void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
CodeGenC::InitFuncState(f);
// skip the first underscore, so SSA variable starts from
name_supply_->FreshName("v_");
// Setup the thread group info.
ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");

// analyze the data;
for (Var arg : f->params) {
if (arg.dtype().is_handle()) {
Expand Down Expand Up @@ -159,7 +153,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re
std::ostringstream os_param_access;
os_param_access << "paramWriteAccess:[";
// setup buffer argumemts
for (Var arg : func->params) {
for (Var arg : f->params) {
DataType t = arg.dtype();
func_info.arg_types.push_back(t);

Expand Down Expand Up @@ -706,19 +700,14 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) {
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute";
functions.Set(gvar, prim_func);
}

std::unordered_map<std::string, std::string> smap;
for (auto [gvar, prim_func] : functions) {
CodeGenWebGPU cg(target);
std::string f_name = global_symbol.value();
cg.Init(output_ssa);
fmap[f_name] = cg.AddFunction(f, skip_readonly_decl);
std::string code = cg.Finish();
smap[cg.GetFunctionName(gvar)] = code;
smap[f_name] = code;
}

auto n = make_object<WebGPUSourceModuleNode>(smap, fmap);
Expand Down
2 changes: 1 addition & 1 deletion web/tests/python/webgpu_rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_rpc():
temp = utils.tempdir()

wasm_path = temp.relpath("addone_gpu.wasm")
fadd.export_library(wasm_path, fcompile=emcc.create_tvmjs_wasm)
fadd.export_library(wasm_path, fcompile=tvmjs.create_tvmjs_wasm)

wasm_binary = open(wasm_path, "rb").read()
remote = rpc.connect(
Expand Down

0 comments on commit 393aaa3

Please sign in to comment.