@@ -346,6 +346,36 @@ def bmm(
346346 Z [vb , vi , vk ] = Z [vb , vi , vk ] + X [vb , vi , vk ] * Y [vb , vk , vj ]
347347
348348
# SDDMM-style kernel: every stored position (vi, vj) of the sparse output C
# receives the reduction over vk of A[vi, vk] * B[vj, vk].
@T.prim_func
def sddmm(
    a: T.handle,
    b: T.handle,
    c: T.handle,
    indptr: T.handle,
    indices: T.handle,
    m: T.int32,
    n: T.int32,
    k: T.int32,
    nnz: T.int32,
) -> None:
    # Axes: I and K are dense with fixed extents m and k; J is a
    # variable-length sparse axis whose parent is I and whose structure is
    # read from (indptr, indices) with "int32" indices.
    # NOTE(review): (n, nnz) is presumably (column bound, number of stored
    # entries) -- confirm against T.sparse_variable.
    I = T.dense_fixed(m)
    J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32")
    K = T.dense_fixed(k)

    # A is laid out over (I, K); B uses the densified form of J as its first
    # axis; C carries the sparse (I, J) layout.
    A = T.match_sparse_buffer(a, (I, K), "float32")
    B = T.match_sparse_buffer(b, (T.dense(J), K), "float32")
    C = T.match_sparse_buffer(c, (I, J), "float32")

    # NOTE(review): "SSR" presumably marks vi and vj as spatial and vk as the
    # reduction iterator, which matches the init/accumulate pattern below.
    with T.iter([I, J, K], "SSR", "sddmm") as [vi, vj, vk]:
        with T.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
362+
363+
# Variant of the SDDMM kernel in which the I and J axes are handed to T.iter
# as a single T.fuse(I, J) entry; the computation written in the body is the
# same init-then-accumulate over vk.
@T.prim_func
def fused_sddmm(
    a: T.handle,
    b: T.handle,
    c: T.handle,
    indptr: T.handle,
    indices: T.handle,
    m: T.int32,
    n: T.int32,
    k: T.int32,
    nnz: T.int32,
) -> None:
    # Axes: I and K are dense with fixed extents m and k; J is a
    # variable-length sparse axis whose parent is I and whose structure is
    # read from (indptr, indices) with "int32" indices.
    # NOTE(review): (n, nnz) is presumably (column bound, number of stored
    # entries) -- confirm against T.sparse_variable.
    I = T.dense_fixed(m)
    J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32")
    K = T.dense_fixed(k)

    # A is laid out over (I, K); B uses the densified form of J as its first
    # axis; C carries the sparse (I, J) layout.
    A = T.match_sparse_buffer(a, (I, K), "float32")
    B = T.match_sparse_buffer(b, (T.dense(J), K), "float32")
    C = T.match_sparse_buffer(c, (I, J), "float32")

    # The axis list has two entries (the fused I/J pair and K), yet three
    # iteration variables are still bound: vi, vj and vk.
    with T.iter([T.fuse(I, J), K], "SSR", "sddmm") as [vi, vj, vk]:
        with T.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
377+
378+
349379@T .prim_func
350380def square_sum (a : T .handle , b : T .handle , indptr_j : T .handle , indices_j : T .handle , indptr_k : T .handle , indices_k : T .handle , nnz_j : T .int32 , nnz_k : T .int32 , M : T .int32 , N1 : T .int32 , N2 : T .int32 ):
351381 I = T .dense_fixed (M )
@@ -616,7 +646,20 @@ def test_csr_element_wise():
def test_bmm():
    """Build an IRModule from ``bmm`` and run LowerSparseTIR over it.

    The lowered module is not inspected; the test only exercises the pass.
    """
    bmm_mod = tvm.IRModule.from_expr(bmm)
    lower_pass = tvm.tir.transform.LowerSparseTIR()
    lower_pass(bmm_mod)
    # TODO
650+
651+
def test_sddmm():
    """Lower the ``sddmm`` kernel with LowerSparseTIR and print the result.

    Only the script of the lowered ``main`` function is printed; nothing is
    asserted about it.
    """
    sddmm_mod = tvm.IRModule.from_expr(sddmm)
    lower_pass = tvm.tir.transform.LowerSparseTIR()
    lowered_mod = lower_pass(sddmm_mod)
    print(lowered_mod["main"].script())
    # TODO
657+
658+
def test_fused_sddmm():
    """Build an IRModule from ``fused_sddmm`` and print its ``main`` script.

    The module is printed as constructed: LowerSparseTIR is not applied
    here, and nothing is asserted about the output.
    """
    fused_mod = tvm.IRModule.from_expr(fused_sddmm)
    print(fused_mod["main"].script())
    # TODO
620663
621664
622665def test_square_sum ():
@@ -707,6 +750,8 @@ def test_square_sum_two_K():
707750 test_bsrmm ()
708751 test_ellpack_mm ()
709752 test_csr_element_wise ()
753+ test_sddmm ()
754+ test_fused_sddmm ()
710755 test_bmm ()
711756 test_square_sum ()
712757 test_square_sum_two_K ()
0 commit comments