@@ -3,7 +3,6 @@
 import tilelang.language as T
 from typing import Tuple, Optional
 
-
 tilelang.set_log_level("WARNING")
 
 pass_configs = {
@@ -34,9 +33,7 @@ def fast_round_scale(amax, fp8_max_inv):
 
 
 @tilelang.jit(pass_configs=pass_configs)
-def act_quant_kernel(
-    N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
-):
+def act_quant_kernel(N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False):
     M = T.symbolic("M")
     fp8_min = -448.0
     fp8_max = 448.0
@@ -51,10 +48,11 @@ def act_quant_kernel_(
         Y: T.Tensor[(M, N), out_dtype],
         S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
     ):
-        with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
-            pid_m,
-            pid_n,
-        ):
+        with T.Kernel(
+                T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
+                    pid_m,
+                    pid_n,
+                ):
             x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
             x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
             amax_local = T.alloc_fragment((blk_m,), scale_dtype)
@@ -73,9 +71,7 @@ def act_quant_kernel_(
                     else:
                         s_local[i] = amax_local[i] * fp8_max_inv
                 for i, j in T.Parallel(blk_m, group_size):
-                    y_local[i, j] = T.clamp(
-                        x_local[i, j] / s_local[i], fp8_min, fp8_max
-                    )
+                    y_local[i, j] = T.clamp(x_local[i, j] / s_local[i], fp8_min, fp8_max)
                 for i in T.Parallel(blk_m):
                     S[pid_m * blk_m + i, pid_n] = s_local[i]
                 T.copy(y_local, y_shared)
@@ -84,9 +80,9 @@ def act_quant_kernel_(
     return act_quant_kernel_
 
 
-def act_quant(
-    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
-) -> Tuple[torch.Tensor, torch.Tensor]:
+def act_quant(x: torch.Tensor,
+              block_size: int = 128,
+              scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantizes the input tensor `x` using block-wise quantization.
 
@@ -101,8 +97,7 @@ def act_quant(
     """
     assert x.is_contiguous(), "Input tensor must be contiguous"
     assert x.size(-1) % block_size == 0, (
-        f"Last dimension size must be divisible by block_size (block_size={block_size})"
-    )
+        f"Last dimension size must be divisible by block_size (block_size={block_size})")
     N = x.size(-1)
     y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
     s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
@@ -129,10 +124,11 @@ def fp8_gemm_kernel_(
         scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32],
         scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), FP32],
     ):
-        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
-            bx,
-            by,
-        ):
+        with T.Kernel(
+                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
+                    bx,
+                    by,
+                ):
             A_shared = T.alloc_shared((block_M, block_K), FP8)
             B_shared = T.alloc_shared((block_N, block_K), FP8)
             C_shared = T.alloc_shared((block_M, block_N), out_dtype)
@@ -168,9 +164,8 @@ def fp8_gemm_kernel_(
     return fp8_gemm_kernel_
 
 
-def fp8_gemm(
-    a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor
-) -> torch.Tensor:
+def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor,
+             b_s: torch.Tensor) -> torch.Tensor:
     """
     Perform a matrix multiplication using FP8 precision.
 
@@ -185,8 +180,7 @@ def fp8_gemm(
     """
     assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
     assert a_s.is_contiguous() and b_s.is_contiguous(), (
-        "Scaling factor tensors must be contiguous"
-    )
+        "Scaling factor tensors must be contiguous")
     K = a.size(-1)
     M = a.numel() // K
     N = b.size(0)
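
For anyone trying out the reformatted file, here is a minimal usage sketch wiring `act_quant` into `fp8_gemm`, following the signatures shown in this diff. It is not part of the commit: the device, sizes, and especially the weight-side quantization are assumptions — `fp8_gemm` expects per-128x128-block weight scales of shape `(ceildiv(N, 128), ceildiv(K, 128))`, which `act_quant` does not produce, so the weight tensor and its scales are stubbed with placeholders here.

```python
# Minimal usage sketch (illustrative, not from this commit).
# Assumes a CUDA device and a PyTorch build with float8_e4m3fn support.
import torch

M, K, N = 1024, 7168, 4096
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")

# Activation path: per-row, 128-wide groups.
# y: (M, K) fp8, x_s: (M, K // 128) fp32 scales.
y, x_s = act_quant(x, block_size=128)

# Weight path: fp8 values plus one scale per 128x128 block,
# b_s: (N // 128, K // 128). A real pipeline would produce these with
# its own weight quantizer; the cast and unit scales below are placeholders.
b = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
b_fp8 = b.to(torch.float8_e4m3fn)
b_s = torch.ones(N // 128, K // 128, dtype=torch.float32, device="cuda")

out = fp8_gemm(y, x_s, b_fp8, b_s)  # (M, N), in the kernel's out_dtype
```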