11from  tilelang  import  tvm  as  tvm 
22import  tilelang  as  tl 
3- from  tilelang .utils .target  import  determine_target 
43import  tilelang .language  as  T 
54import  tilelang .testing 
65from  tvm  import  tir 
76
87tilelang .disable_cache ()
98
9+ 
1010def  test_inject_set_max_nreg ():
1111    """Test the InjectSetMaxNReg pass""" 
1212
@@ -37,21 +37,26 @@ def before(A: T.Tensor((512, 512), "float16"), B: T.Tensor((512, 512), "float16"
3737                    T .mbarrier_wait_parity (T .get_mbarrier (k  %  3  +  3 ), T .bitwise_xor (k  //  3  %  2 , 1 ))
3838                    if  v  -  128  ==  0 :
3939                        T .tma_load (
40-                             T .create_tma_descriptor (6 , 2 , A .data , 512 , 512 , 2 , 1024 , 32 , 64 , 1 , 1 , 0 , 2 , 2 , 0 ),
41-                             T .get_mbarrier (k  %  3 ),
42-                             T .tvm_access_ptr (T .type_annotation ("float16" ), A_shared .data , k  %  3  *  2048 , 2048 , 2 ),
40+                             T .create_tma_descriptor (6 , 2 , A .data , 512 , 512 , 2 , 1024 , 32 , 64 , 1 , 1 ,
41+                                                     0 , 2 , 2 , 0 ), T .get_mbarrier (k  %  3 ),
42+                             T .tvm_access_ptr (
43+                                 T .type_annotation ("float16" ), A_shared .data , k  %  3  *  2048 , 2048 , 2 ),
4344                            k  *  32 , by  *  64 )
44-                     T .evaluate (tir .Call ("handle" , "tir.ptx_arrive_barrier" , [T .get_mbarrier (k  %  3 )]))
45+                     T .evaluate (
46+                         tir .Call ("handle" , "tir.ptx_arrive_barrier" , [T .get_mbarrier (k  %  3 )]))
4547            else :
4648                # Consumer branch - should have set_max_nreg(240, 1) 
4749                for  k  in  range (16 ):
4850                    T .mbarrier_wait_parity (T .get_mbarrier (k  %  3 ), k  //  3  %  2 )
4951                    T .call_extern (
5052                        "handle" , "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>" ,
51-                         T .tvm_access_ptr (T .type_annotation ("float16" ), A_shared .data , k  %  3  *  2048 , 2048 , 1 ),
52-                         T .tvm_access_ptr (T .type_annotation ("float16" ), B_shared .data , k  %  3  *  2048 , 2048 , 1 ),
53+                         T .tvm_access_ptr (
54+                             T .type_annotation ("float16" ), A_shared .data , k  %  3  *  2048 , 2048 , 1 ),
55+                         T .tvm_access_ptr (
56+                             T .type_annotation ("float16" ), B_shared .data , k  %  3  *  2048 , 2048 , 1 ),
5357                        T .tvm_access_ptr (T .type_annotation ("float32" ), C_local .data , 0 , 32 , 3 ))
54-                     T .evaluate (tir .Call ("handle" , "tir.ptx_arrive_barrier" , [T .get_mbarrier (k  %  3  +  3 )]))
58+                     T .evaluate (
59+                         tir .Call ("handle" , "tir.ptx_arrive_barrier" , [T .get_mbarrier (k  %  3  +  3 )]))
5560
5661    # Apply the InjectSetMaxNReg pass 
5762    func  =  before 
@@ -64,15 +69,15 @@ def before(A: T.Tensor((512, 512), "float16"), B: T.Tensor((512, 512), "float16"
6469    set_max_nreg_calls  =  []
6570
6671    def  collect_set_max_nreg (stmt ):
67-         if  isinstance (stmt , tvm .tir .Evaluate ):
68-             if  hasattr (stmt .value , 'op' ) and  hasattr (stmt .value .op , 'name' ):
69-                 if  stmt .value .op .name  ==  "tl.set_max_nreg" :
70-                     set_max_nreg_calls .append ((stmt .value .args [0 ].value , stmt .value .args [1 ].value ))
72+         if  (isinstance (stmt , tvm .tir .Evaluate ) and  hasattr (stmt .value , 'op' ) and 
73+                 hasattr (stmt .value .op , 'name' ) and  stmt .value .op .name  ==  "tl.set_max_nreg" ):
74+             set_max_nreg_calls .append (stmt .value )
7175
7276    tvm .tir .stmt_functor .post_order_visit (main_func .body , collect_set_max_nreg )
7377
7478    # We should have at least 2 set_max_nreg calls (one for producer, one for consumer) 
75-     assert  len (set_max_nreg_calls ) >=  2 , f"Expected at least 2 set_max_nreg calls, got { len (set_max_nreg_calls )}  
79+     assert  len (set_max_nreg_calls 
80+               ) >=  2 , f"Expected at least 2 set_max_nreg calls, got { len (set_max_nreg_calls )}  
7681
7782    # Check that we have the expected register values 
7883    reg_values  =  [call [0 ] for  call  in  set_max_nreg_calls ]
@@ -118,15 +123,16 @@ def before_no_set_max_nreg(A: T.Tensor((512, 512), "float16")):
118123    set_max_nreg_calls  =  []
119124
120125    def  collect_set_max_nreg (stmt ):
121-         if  isinstance (stmt , tvm .tir .Evaluate ):
122-             if  hasattr (stmt .value , 'op' ) and  hasattr (stmt .value .op , 'name' ):
123-                 if  stmt .value .op .name  ==  "tl.set_max_nreg" :
124-                     set_max_nreg_calls .append (stmt .value )
126+         if  (isinstance (stmt , tvm .tir .Evaluate ) and  hasattr (stmt .value , 'op' ) and 
127+                 hasattr (stmt .value .op , 'name' ) and  stmt .value .op .name  ==  "tl.set_max_nreg" ):
128+             set_max_nreg_calls .append (stmt .value )
125129
126130    tvm .tir .stmt_functor .post_order_visit (main_func .body , collect_set_max_nreg )
127131
128132    # Should have no set_max_nreg calls when no_set_max_nreg is present 
129-     assert  len (set_max_nreg_calls ) ==  0 , f"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got { len (set_max_nreg_calls )}  
133+     assert  len (
134+         set_max_nreg_calls 
135+     ) ==  0 , f"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got { len (set_max_nreg_calls )}  
130136
131137    print ("InjectSetMaxNReg with no_set_max_nreg test passed!" )
132138
0 commit comments