@@ -69,9 +69,7 @@ def matmul(M,
6969           scale_size = 32 ,
7070           tune = False ):
7171
72-     @tilelang .jit ( 
73-         out_idx = [- 1 ], 
74-     ) 
72+     @tilelang .jit (out_idx = [- 1 ],) 
7573    def  kernel_func (block_M , block_N , block_K , num_stages , threads , split = 1 ):
7674        num_elems_per_byte  =  8  //  num_bits 
7775        storage_dtype  =  "uint8" 
@@ -81,7 +79,6 @@ def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
8179        A_shared_shape  =  (block_M , block_K )
8280        B_shared_shape  =  (block_N , block_K  //  num_elems_per_byte )
8381        B_dequantize_shared_shape  =  (block_N , block_K )
84-         Scale_shared_shape  =  (block_N , block_K  //  scale_size )
8582        assert  K  %  (block_K  *  split ) ==  0 
8683
8784        # Some variables for serial dequant in each thread 
@@ -121,7 +118,6 @@ def main(
121118                B_local_thread  =  T .alloc_local ((local_compress_size ,), storage_dtype )
122119                B_dequantize_local_thread  =  T .alloc_local ((local_size ,), in_dtype )
123120                B_dequantize_shared  =  T .alloc_shared (B_dequantize_shared_shape , in_dtype )
124-                 Scale_shared  =  T .alloc_shared (Scale_shared_shape , storage_dtype )
125121                Scale_local_thread  =  T .alloc_local ((1 ,), storage_dtype )
126122                Scale_local_thread_exponent  =  T .alloc_local ((1 ,), "float32" )
127123
@@ -158,7 +154,8 @@ def main(
158154                        index_scale  =  index_base  //  (scale_size  //  num_elems_per_byte )
159155                        si  =  index_scale  //  (block_K  //  scale_size )
160156                        sj  =  index_scale  %  (block_K  //  scale_size )
161-                         Scale_local_thread [0 ] =  Scale [bx  *  block_N  +  si , k  *  block_K  //  scale_size  +  sj ]
157+                         Scale_local_thread [0 ] =  Scale [bx  *  block_N  +  si ,
158+                                                       k  *  block_K  //  scale_size  +  sj ]
162159                        Scale_local_thread_exponent [0 ] =  T .exp2 (
163160                            T .cast (Scale_local_thread [0 ] -  127 , "float" ))
164161
0 commit comments