@@ -40,8 +40,8 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
4040    # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 
4141    e_bf16  =  e_f4  +  tir .const (126 , "uint16" )
4242    # Scale is the exponential part, within the representation of uint8 
43-     # To handle the overflow, we use the max  function to limit the exponential part to 8 bits 
44-     e_bf16  =  T .min (e_bf16  +  scale , tir .const ((1  <<  8 ) -  1 , "uint16" ))
43+     # To handle overflow, the exponent could be clamped to 8 bits with min (currently disabled):
44+     #  e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
4545    m_f4  =  f4  &  tir .const (1 , "uint16" )
4646    val_bf16  =  tir .reinterpret ("bfloat16" ,
4747                               ((((s  <<  tir .const (8 , "uint16" )) |  e_bf16 ) <<  tir .const (7 , "uint16" ))
@@ -218,7 +218,7 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k):
218218            B_local_thread  =  T .alloc_local ((local_compress_size ,), storage_dtype )
219219            B_dequantize_local_thread  =  T .alloc_local ((local_size ,), out_dtype )
220220            Scale_local_thread  =  T .alloc_local ((1 ,), storage_dtype )
221-             Scale_local_thread_exponent  =  T .alloc_local ((1 ,), "float32" )
221+             Scale_local_thread_exponent  =  T .alloc_local ((1 ,), out_dtype )
222222
223223            for  i  in  T .serial (0 , block_N  *  block_K  //  threads  //  local_size ):
224224                # First, load data from share memory to register. 
@@ -231,8 +231,7 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k):
231231                si  =  index_scale  //  (block_K  //  scale_size )
232232                sj  =  index_scale  %  (block_K  //  scale_size )
233233                Scale_local_thread [0 ] =  Scale [bx  *  block_N  +  si , k  *  block_K  //  scale_size  +  sj ]
234-                 Scale_local_thread_exponent [0 ] =  T .exp2 (
235-                     T .cast (Scale_local_thread [0 ] -  127 , "float" ))
234+                 Scale_local_thread_exponent [0 ] =  T .shift_left (1 , (Scale_local_thread [0 ]))
236235
237236                # Then, dequant. 
238237                T .call_extern (
@@ -288,7 +287,7 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k):
288287            - Mutates B_dequantize_shared by storing the dequantized BF16 fragment. 
289288            """ 
290289            B_local  =  T .alloc_fragment (B_shared_shape , storage_dtype )
291-             B_dequantize_local  =  T .alloc_fragment (B_dequantize_shared_shape , in_dtype )
290+             B_dequantize_local  =  T .alloc_fragment (B_dequantize_shared_shape , out_dtype )
292291
293292            bx  =  T .get_block_binding (0 )
294293            T .copy (B_shared , B_local )
@@ -300,8 +299,9 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k):
300299                    Scale [
301300                        bx  *  block_N  +  i , k  *  block_K  //  scale_size  +  j  // 
302301                        scale_size ],  # Scale is the exponential part, within the representation of uint8 
303-                     dtype = in_dtype ,
304-                 )
302+                     dtype = out_dtype ,
303+                 ) *  T .shift_left (
304+                     1 , (Scale [bx  *  block_N  +  i , k  *  block_K  //  scale_size  +  j  //  scale_size ]))
305305            T .copy (B_dequantize_local , B_dequantize_shared )
306306
307307        return  simple_dequant_bf16_fp4 
@@ -374,7 +374,7 @@ def ref_program_twiddling(A, qB, Scale):
374374    B  =  torch_convert_bit_twiddling (qB )
375375    for  i  in  range (B .shape [0 ]):
376376        for  j  in  range (B .shape [1 ]):
377-             B [i ][j ] =  B [i ][j ] *  (2 ** (Scale [i ][j  //  32 ]  -   127 ))
377+             B [i ][j ] =  B [i ][j ] *  (2 ** (Scale [i ][j  //  32 ]))
378378    C  =  torch .matmul (A .to (torch .float ), B .T .to (torch .float ))
379379    C  =  C .to (torch .__getattribute__ (dtypeC ))
380380    return  C 
@@ -400,7 +400,7 @@ def ref_program_simple(A, qB, Scale):
400400    B  =  torch_convert (qB )
401401    for  i  in  range (B .shape [0 ]):
402402        for  j  in  range (B .shape [1 ]):
403-             B [i ][j ] =  B [i ][j ] *  (2 ** (Scale [i ][j  //  32 ]  -   127 ))
403+             B [i ][j ] =  B [i ][j ] *  (2 ** (Scale [i ][j  //  32 ]))
404404    C  =  torch .matmul (A .to (torch .float ), B .T .to (torch .float ))
405405    C  =  C .to (torch .__getattribute__ (dtypeC ))
406406    return  C 
@@ -427,7 +427,15 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
427427
428428    if  tune :
429429        kernel  =  matmul (
430-             m , n , k , "bfloat16" , "bfloat16" , "float32" , num_bits = 4 , scale_size = scale_size )
430+             m ,
431+             n ,
432+             k ,
433+             "bfloat16" ,
434+             "bfloat16" ,
435+             "float32" ,
436+             num_bits = 4 ,
437+             scale_size = scale_size ,
438+             fast_dequant = fast_dequant )
431439    else :
432440        kernel  =  matmul (
433441            m ,
@@ -443,7 +451,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
443451            block_K = 128 ,
444452            num_stages = 2 ,
445453            threads = 256 ,
446-             split = 1 )
454+             split = 1 ,
455+             fast_dequant = fast_dequant )
447456
448457    profiler  =  kernel .get_profiler (tilelang .TensorSupplyType .Auto )
449458
0 commit comments