Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/transform/inject_fence_proxy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ bool IsAsyncIntrinsic(const CallNode *call) {
return true;
}

// wgmma async intrinsics
if (call->op.same_as(tl_gemm()) || call->op.same_as(tl_gemm_sp())) {
return true;
}

return false;
}

Expand Down Expand Up @@ -208,8 +213,10 @@ class ProxyFenceInjector : public StmtMutator {
} else if (IsKnownGeneric(call)) {
kind = ProxyKind::kGeneric;
} else {
// Treat unknown externs as async to avoid missing required fences.
kind = ProxyKind::kAsync;
// We can now treat extern as Generic, since gemm and gemm_sp are never
// represented as call_extern nodes. They are call_intrin nodes and will
// be handled by IsAsyncIntrinsic above.
kind = ProxyKind::kGeneric;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def before():
C_local = T.decl_buffer((32,), scope="local")
for i in T.unroll(16):
C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
T.call_extern("handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.call_intrin("handle", tir.op.Op.get("tl.tl_gemm"),
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
Expand All @@ -45,7 +46,8 @@ def after():
for i in T.unroll(16):
C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
T.fence_proxy_async()
T.call_extern("handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.call_intrin("handle", tir.op.Op.get("tl.tl_gemm"),
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
Expand Down Expand Up @@ -169,7 +171,6 @@ def before():
mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
mod = tvm.tir.transform.BindTarget(auto_target)(mod)
mod = tl.transform.InjectFenceProxy()(mod)

order = []

def visit(node):
Expand All @@ -185,43 +186,5 @@ def visit(node):
assert order.index("tl.fence_proxy_async") < order.index("tl.ptx_wgmma_ss")


def test_wgmma_after_descriptor():
    """Check fence placement when a WGMMA follows descriptor initialization.

    The kernel initializes two descriptors, calls ``T.warpgroup_arrive()``
    and then issues ``T.ptx_wgmma_ss``. After running ``InjectFenceProxy``
    the body must contain at least one ``tl.fence_proxy_async`` call, and
    the first such call must be visited before the first
    ``tl.warpgroup_arrive`` call.
    """

    @T.prim_func
    def before():
        with T.Kernel(1):
            desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor")
            desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor")
            C_local = T.decl_buffer((32,), "float16", scope="local")
            T.initialize_descriptor(desc_a, T.uint64(0), 2, 1, 32)
            T.initialize_descriptor(desc_b, T.uint64(0), 2, 1, 32)
            T.warpgroup_arrive()
            T.ptx_wgmma_ss("float16", "m64n64k16", T.bool(True), T.bool(True), "fp16", "fp16",
                           "fp16", desc_a.data, T.int32(0), desc_b.data, T.int32(0), C_local.data,
                           T.int32(0), T.bool(True), 1, 1)

    # Wrap the kernel in a module, bind the target, then run the pass under test.
    mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
    mod = tvm.tir.transform.BindTarget(auto_target)(mod)
    mod = tl.transform.InjectFenceProxy()(mod)

    # Op names of every Evaluate(Call) statement, in post-order visit order.
    order = []

    def visit(node):
        if not isinstance(node, tir.Evaluate):
            return
        call = node.value
        if not isinstance(call, tir.Call):
            return
        # Ops without a ``name`` attribute are recorded as "".
        order.append(getattr(call.op, "name", ""))

    tir.stmt_functor.post_order_visit(mod["main"].body, visit)

    fence_count = order.count("tl.fence_proxy_async")
    assert fence_count >= 1
    assert "tl.warpgroup_arrive" in order
    assert order.index("tl.fence_proxy_async") < order.index("tl.warpgroup_arrive")


# When executed directly as a script, hand control to tilelang's test entry point.
if __name__ == "__main__":
    tilelang.testing.main()