1919import tvm .tir as tir
2020import scipy .sparse as sp
2121import numpy as np
22+ import pytest
2223from tvm .script import tir as T
2324
2425
@@ -367,7 +368,7 @@ def lowered_square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j:
367368 J_indices = T .match_buffer (indices_j , [nnz_j ], dtype = "int32" )
368369 K_indptr = T .match_buffer (indptr_k , [nnz_j + 1 ], dtype = "int32" )
369370 K_indices = T .match_buffer (indices_k , [nnz_k ], dtype = "int32" )
370-
371+
371372 for v_vi in T .serial (0 , M ):
372373 with T .block ("square_sum_2" ):
373374 vi = T .axis .spatial (M , v_vi )
@@ -391,6 +392,58 @@ def lowered_square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j:
391392 B_data [vi ] = B_data [vi ] + A_data [K_indptr [J_indptr [vi ] + vj ] + vk ]
392393
393394
395+ @T .prim_func
396+ def square_sum_two_K (a : T .handle , b : T .handle , indptr_j : T .handle , indices_j : T .handle , indptr_k0 : T .handle , indices_k0 : T .handle , indptr_k1 : T .handle , indices_k1 : T .handle , nnz_j : T .int32 , nnz_k : T .int32 , M : T .int32 , N1 : T .int32 , N2 : T .int32 ):
397+ # Used only for testing `GetIndicesRange()`.
398+ # Currently it is ensured that `indptr_k0` is the same as `indptr_k1`, and `indices_k0` is the
399+ # same as `indices_k1`.
400+ I = T .dense_fixed (M )
401+ J = T .sparse_variable (I , (N1 , nnz_j ), (indptr_j , indices_j ), "int32" )
402+ K0 = T .sparse_variable (J , (N2 , nnz_k ), (indptr_k0 , indices_k0 ), "int32" )
403+ K1 = T .sparse_variable (J , (N2 , nnz_k ), (indptr_k1 , indices_k1 ), "int32" )
404+ A = T .match_sparse_buffer (a , (I , J , K0 ), "float32" )
405+ B = T .match_sparse_buffer (b , (I ,), "float32" )
406+
407+ with T .iter ([I , J , K1 ], "SRR" , "square_sum" ) as [vi , vj , vk ]:
408+ with T .init ():
409+ B [vi ] = 0.0
410+ B [vi ] = B [vi ] + A [vi , vj , vk ]
411+
412+
413+ @T .prim_func
414+ def lowered_square_sum_two_K (a : T .handle , b : T .handle , indptr_j : T .handle , indices_j : T .handle , indptr_k0 : T .handle , indices_k0 : T .handle , indptr_k1 : T .handle , indices_k1 : T .handle , nnz_j : T .int32 , nnz_k : T .int32 , M : T .int32 , N1 : T .int32 , N2 : T .int32 ) -> None :
415+ A_data = T .match_buffer (a , [nnz_k ], dtype = "float32" )
416+ B_data = T .match_buffer (b , [M ], dtype = "float32" )
417+ J_indptr = T .match_buffer (indptr_j , [M + 1 ], dtype = "int32" )
418+ J_indices = T .match_buffer (indices_j , [nnz_j ], dtype = "int32" )
419+ K0_indptr = T .match_buffer (indptr_k0 , [nnz_j + 1 ], dtype = "int32" )
420+ K0_indices = T .match_buffer (indices_k0 , [nnz_k ], dtype = "int32" )
421+ K1_indptr = T .match_buffer (indptr_k1 , [nnz_j + 1 ], dtype = "int32" )
422+ K1_indices = T .match_buffer (indices_k1 , [nnz_k ], dtype = "int32" )
423+
424+ for v_vi in T .serial (0 , M ):
425+ with T .block ("square_sum_2" ):
426+ vi = T .axis .spatial (M , v_vi )
427+ T .reads ([J_indptr [0 : M + 1 ], J_indices [0 : nnz_j ], K0_indptr [0 : nnz_j + 1 ], K0_indices [0 : nnz_k ], K1_indptr [0 : nnz_j + 1 ], K1_indices [0 : nnz_k ], A_data [0 : nnz_k ], B_data [0 : M ]])
428+ T .writes ([B_data [0 : M ]])
429+ T .block_attr ({"sparse" :True })
430+ for v_vj in T .serial (0 , J_indptr [v_vi + 1 ] - J_indptr [v_vi ]):
431+ with T .block ("square_sum_1" ):
432+ vj = T .axis .reduce (J_indptr [v_vi + 1 ] - J_indptr [v_vi ], v_vj )
433+ T .reads ([J_indptr [0 : M + 1 ], J_indices [0 : nnz_j ], K0_indptr [0 : nnz_j + 1 ], K0_indices [0 : nnz_k ], K1_indptr [0 : nnz_j + 1 ], K1_indices [0 : nnz_k ], A_data [0 : nnz_k ], B_data [0 : M ]])
434+ T .writes ([B_data [0 : M ]])
435+ T .block_attr ({"sparse" :True })
436+ with T .init ():
437+ B_data [vi ] = T .float32 (0 )
438+ for v_vk in T .serial (0 , K1_indptr [J_indptr [v_vi ] + v_vj + 1 ] - K1_indptr [J_indptr [v_vi ] + v_vj ]):
439+ with T .block ("square_sum" ):
440+ vk = T .axis .reduce (K1_indptr [J_indptr [v_vi ] + v_vj + 1 ] - K1_indptr [J_indptr [v_vi ] + v_vj ], v_vk )
441+ T .reads ([J_indptr [0 : M + 1 ], J_indices [0 : nnz_j ], K0_indptr [0 : nnz_j + 1 ], K0_indices [0 : nnz_k ], K1_indptr [0 : nnz_j + 1 ], K1_indices [0 : nnz_k ], A_data [0 : nnz_k ], B_data [0 : M ]])
442+ T .writes ([B_data [0 : M ]])
443+ T .block_attr ({"sparse" :True })
444+ B_data [vi ] = B_data [vi ] + A_data [T .tvm_lower_bound (K0_indices .data , K1_indices [K1_indptr [J_indptr [vi ] + vj ] + vk ], K0_indptr [J_indptr [vi ] + vj ], K0_indptr [J_indptr [vi ] + vj + 1 ], dtype = "int32" )]
445+
446+
394447def test_csrmm ():
395448 mod = tvm .IRModule .from_expr (csrmm )
396449 mod = tvm .tir .transform .LowerSparseTIR ()(mod )
@@ -414,13 +467,15 @@ def test_csrmm():
414467 tvm .testing .assert_allclose (y_ground_truth .reshape (- 1 ), Y_nd .numpy (), rtol = 1e-5 , atol = 1e-5 )
415468
416469
470+ @pytest .mark .skip (reason = "Under implementation" )
417471def test_csrmm_dense_iter ():
418472 mod = tvm .IRModule .from_expr (csrmm_dense_iter )
419473 mod = tvm .tir .transform .LowerSparseTIR ()(mod )
420474 # tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True)
421475 # Todo
422476
423477
478+ @pytest .mark .skip (reason = "Under implementation" )
424479def test_segment_reduce ():
425480 mod = tvm .IRModule .from_expr (segment_reduce )
426481 mod = tvm .tir .transform .LowerSparseTIR ()(mod )
@@ -557,6 +612,7 @@ def test_csr_element_wise():
557612 tvm .testing .assert_allclose (b_ground_truth .data .reshape (- 1 ), B_nd .numpy (), rtol = 1e-5 , atol = 1e-5 )
558613
559614
615+ @pytest .mark .skip (reason = "Under implementation" )
560616def test_bmm ():
561617 mod = tvm .IRModule .from_expr (bmm )
562618 mod = tvm .tir .transform .LowerSparseTIR ()(mod )
@@ -600,6 +656,49 @@ def test_square_sum():
600656 tvm .testing .assert_allclose (b_ground_truth , B_data .numpy (), rtol = 1e-5 , atol = 1e-5 )
601657
602658
659+ def test_square_sum_two_K ():
660+ mod = tvm .IRModule .from_expr (square_sum_two_K )
661+ mod = tvm .tir .transform .LowerSparseTIR ()(mod )
662+ tvm .ir .assert_structural_equal (mod ["main" ], lowered_square_sum_two_K , True )
663+
664+ sch = tir .Schedule (mod , debug_mask = "all" )
665+ i , = sch .get_loops (sch .get_block ("square_sum_2" ))
666+ sch .bind (i , "threadIdx.x" )
667+
668+ density = 0.0125
669+ M = N1 = N2 = 128
670+ A_J = sp .random (M , N1 , dtype = "float32" , density = 1 - (1 - density ) ** N2 , format = "csr" )
671+ indptr_j = A_J .indptr
672+ indices_j = A_J .indices
673+ nnz_j = A_J .nnz
674+ A_K = sp .random (nnz_j , N2 , dtype = "float32" , density = density , format = "csr" )
675+ indptr_k = A_K .indptr
676+ indices_k = A_K .indices
677+ nnz_k = A_K .nnz
678+ data = A_K .data
679+
680+ b_ij = np .asarray (A_K .sum (axis = 1 )).squeeze ()
681+ A_J = sp .csr_matrix ((b_ij , indices_j , indptr_j ), shape = (M , N1 ))
682+ b_ground_truth = np .asarray (A_J .sum (axis = 1 )).squeeze ()
683+ b = np .zeros ((M ,)).astype ("float32" )
684+
685+ v_nnz_j , v_nnz_k , v_M , v_N1 , v_N2 = square_sum_two_K .params [- 5 :]
686+ f = tvm .build (sch .mod ["main" ].specialize ({v_nnz_j : nnz_j , v_nnz_k : nnz_k , v_M : M , v_N1 : N1 , v_N2 : N2 }), target = "cuda" )
687+
688+ ctx = tvm .device ("cuda" )
689+ A_data = tvm .nd .array (data .astype ("float32" ), device = ctx )
690+ A_indptr_j = tvm .nd .array (indptr_j .astype ("int32" ), device = ctx )
691+ A_indices_j = tvm .nd .array (indices_j .astype ("int32" ), device = ctx )
692+ A_indptr_k0 = tvm .nd .array (indptr_k .astype ("int32" ), device = ctx )
693+ A_indices_k0 = tvm .nd .array (indices_k .astype ("int32" ), device = ctx )
694+ A_indptr_k1 = tvm .nd .array (indptr_k .astype ("int32" ), device = ctx )
695+ A_indices_k1 = tvm .nd .array (indices_k .astype ("int32" ), device = ctx )
696+ B_data = tvm .nd .array (b .astype ("float32" ), device = ctx )
697+ f (A_data , B_data , A_indptr_j , A_indices_j , A_indptr_k0 , A_indices_k0 , A_indptr_k1 , A_indices_k1 )
698+
699+ tvm .testing .assert_allclose (b_ground_truth , B_data .numpy (), rtol = 1e-5 , atol = 1e-5 )
700+
701+
603702if __name__ == "__main__" :
604703 test_csrmm ()
605704 test_csrmm_dense_iter ()
@@ -610,3 +709,4 @@ def test_square_sum():
610709 test_csr_element_wise ()
611710 test_bmm ()
612711 test_square_sum ()
712+ test_square_sum_two_K ()
0 commit comments