Skip to content

Commit a8dc5c9

Browse files
yzh119MasterJH5574
authored andcommitted
Fusion syntax fix + SDDMM example. (#39)
1 parent b7c7e47 commit a8dc5c9

File tree

3 files changed

+61
-4
lines changed

3 files changed

+61
-4
lines changed

python/tvm/script/tir/intrin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,5 +271,5 @@ def dense(axis: Axis, span: Optional[Span] = None):
271271

272272

273273
@register
274-
def fuse(group: List[Axis], span: Optional[Span] = None):
275-
return [FusedAxis(group, _) for _ in range(len(group))]
274+
def fuse(*group: List[Axis], span: Optional[Span] = None):
275+
return [FusedAxis(group, i) for i, _ in enumerate(group)]

python/tvm/script/tir/scope_handler.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from tvm.runtime import Object
2525
from tvm.ir import Span, Range
2626
from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
27-
from tvm.tir.sparse import SpIterVar
27+
from tvm.tir.sparse import SpIterVar, Axis
2828

2929
from .node import BufferSlice
3030
from .utils import buffer_slice_to_region
@@ -331,10 +331,22 @@ class SparseBlock(WithScopeHandler):
331331
def __init__(self):
332332

333333
def iter(axes: List, iter_types: str, name: str = "", span: Optional[Span] = None):
334+
335+
# flatten nested axes to axes, to address the special case of fusion.
336+
def flatten_axes(axes: List[Union[Axis, List[Axis]]]) -> List[Axis]:
337+
ret = []
338+
for axis_group in axes:
339+
if isinstance(axis_group, List):
340+
ret += axis_group
341+
else:
342+
ret.append(axis_group)
343+
return ret
344+
334345
assert (
335346
self.node and self.context and self.body
336347
), "call 'exit_scope' before 'enter_scope'"
337348
block_info = self.context.block_info_stack[-1]
349+
axes = flatten_axes(axes)
338350

339351
if len(axes) != len(self.sp_iters):
340352
self.context.report_error(

tests/python/sparsetir/test_tir_sparse_lower.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,36 @@ def bmm(
346346
Z[vb, vi, vk] = Z[vb, vi, vk] + X[vb, vi, vk] * Y[vb, vk, vj]
347347

348348

349+
@T.prim_func
350+
def sddmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None:
351+
I = T.dense_fixed(m)
352+
J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32")
353+
K = T.dense_fixed(k)
354+
A = T.match_sparse_buffer(a, (I, K), "float32")
355+
B = T.match_sparse_buffer(b, (T.dense(J), K), "float32")
356+
C = T.match_sparse_buffer(c, (I, J), "float32")
357+
358+
with T.iter([I, J, K], "SSR", "sddmm") as [vi, vj, vk]:
359+
with T.init():
360+
C[vi, vj] = 0.
361+
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
362+
363+
364+
@T.prim_func
365+
def fused_sddmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None:
366+
I = T.dense_fixed(m)
367+
J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32")
368+
K = T.dense_fixed(k)
369+
A = T.match_sparse_buffer(a, (I, K), "float32")
370+
B = T.match_sparse_buffer(b, (T.dense(J), K), "float32")
371+
C = T.match_sparse_buffer(c, (I, J), "float32")
372+
373+
with T.iter([T.fuse(I, J), K], "SSR", "sddmm") as [vi, vj, vk]:
374+
with T.init():
375+
C[vi, vj] = 0.
376+
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
377+
378+
349379
@T.prim_func
350380
def square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k: T.handle, indices_k: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32):
351381
I = T.dense_fixed(M)
@@ -616,7 +646,20 @@ def test_csr_element_wise():
616646
def test_bmm():
617647
mod = tvm.IRModule.from_expr(bmm)
618648
mod = tvm.tir.transform.LowerSparseTIR()(mod)
619-
# Todo
649+
# TODO
650+
651+
652+
def test_sddmm():
653+
mod = tvm.IRModule.from_expr(sddmm)
654+
mod = tvm.tir.transform.LowerSparseTIR()(mod)
655+
print(mod['main'].script())
656+
# TODO
657+
658+
659+
def test_fused_sddmm():
660+
mod = tvm.IRModule.from_expr(fused_sddmm)
661+
print(mod['main'].script())
662+
# TODO
620663

621664

622665
def test_square_sum():
@@ -707,6 +750,8 @@ def test_square_sum_two_K():
707750
test_bsrmm()
708751
test_ellpack_mm()
709752
test_csr_element_wise()
753+
test_sddmm()
754+
test_fused_sddmm()
710755
test_bmm()
711756
test_square_sum()
712757
test_square_sum_two_K()

0 commit comments

Comments
 (0)