@@ -121,7 +121,7 @@ def matmul(M,
         num_stages (int, optional): pipelining stages for K loop (default 2).
         threads (int, optional): threads per block used by the kernel (default 256).
         split (int, optional): split factor along K used by the scheduler (default 1).
-        with_bias (bool, optional): whether to add bias to the output (default False).
+        with_bias (bool, optional): whether to add Bias to the output (default False).
 
     Returns:
         A T.prim_func implementing the tiled, pipelined GEMM that:
@@ -141,11 +141,11 @@ def matmul(M,
     Block_QK = block_K // num_elems_per_byte
     A_shape = (M, K)
     B_shape = (N, QK)
-    bias_shape = (M, N)
+    Bias_shape = (M, N)
     Scale_shape = (N, K // scale_size)
     A_shared_shape = (block_M, block_K)
     B_shared_shape = (block_N, Block_QK)
-    bias_shared_shape = (block_M, block_N)
+    Bias_shared_shape = (block_M, block_N)
     B_dequantize_shared_shape = (block_N, block_K)
     assert K % (block_K * split) == 0
 
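
For readers tracking the shape bookkeeping in this hunk, a small worked example follows. It assumes `num_bits = 4`, `num_elems_per_byte = 8 // num_bits`, and `QK = K // num_elems_per_byte` as defined earlier in the file (outside this hunk); `scale_size = 32` is likewise an assumption matching the `j // 32` grouping used in the reference programs further down.

```python
# Worked example of the shape setup above (values are illustrative only).
M, N, K = 256, 256, 1024
num_bits = 4
num_elems_per_byte = 8 // num_bits      # 2 packed 4-bit values per storage byte (assumed)
QK = K // num_elems_per_byte            # 512 bytes per row of the packed B
scale_size = 32                         # one scale exponent per 32 columns (assumed)

A_shape = (M, K)                        # (256, 1024)
B_shape = (N, QK)                       # (256, 512)
Bias_shape = (M, N)                     # (256, 256)
Scale_shape = (N, K // scale_size)      # (256, 32)

block_M, block_N, block_K, split = 128, 128, 128, 1
assert K % (block_K * split) == 0       # 1024 % 128 == 0, so the K loop tiles evenly
```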
@@ -315,7 +315,7 @@ def main(
             A: T.Tensor(A_shape, in_dtype),
             B: T.Tensor(B_shape, storage_dtype),
             Scale: T.Tensor(Scale_shape, storage_dtype),
-            bias: T.Tensor(bias_shape, out_dtype),
+            Bias: T.Tensor(Bias_shape, out_dtype),
             C: T.Tensor((M, N), out_dtype),
     ):
         """
@@ -333,7 +333,7 @@ def main(
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)
             B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
             B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
-            bias_shared = T.alloc_shared(bias_shared_shape, out_dtype)
+            Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
             C_shared = T.alloc_shared((block_M, block_N), out_dtype)
 
@@ -345,16 +345,16 @@ def main(
 
             if with_bias:
                 T.annotate_layout({
-                    bias_shared: tilelang.layout.make_swizzled_layout(bias_shared),
+                    Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared),
                 })
 
             if threads == 512:
                 T.disable_warp_group_reg_alloc()
 
             if with_bias:
-                T.copy(bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N],
-                       bias_shared)
-                T.copy(bias_shared, C_local)
+                T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N],
+                       Bias_shared)
+                T.copy(Bias_shared, C_local)
             else:
                 T.clear(C_local)
 
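
The `with_bias` branch above seeds the accumulator fragment with the Bias tile instead of clearing it, so the bias add is fused into the K loop rather than done in a separate epilogue. A minimal PyTorch sketch of why that is equivalent (the names below are illustrative, not the kernel's):

```python
import torch

# Start the accumulator at Bias (mirrors T.copy(Bias_shared, C_local)) and
# accumulate block_K-sized slices (mirrors the pipelined K loop); the result
# matches computing the full GEMM and adding Bias afterwards.
M, N, K, block_K = 64, 64, 256, 64
A, B, Bias = torch.randn(M, K), torch.randn(N, K), torch.randn(M, N)

acc = Bias.clone()
for k0 in range(0, K, block_K):
    acc += A[:, k0:k0 + block_K] @ B[:, k0:k0 + block_K].T

torch.testing.assert_close(acc, A @ B.T + Bias, rtol=1e-4, atol=1e-4)
```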
@@ -373,7 +373,7 @@ def main(
     return main
 
 
-def ref_program_twiddling(A, qB, Scale, bias=None):
+def ref_program_twiddling(A, qB, Scale, Bias=None):
     """
     Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.
 
@@ -397,7 +397,7 @@ def ref_program_twiddling(A, qB, Scale, bias=None):
     return C
 
 
-def ref_program_twiddling_with_bias(A, qB, Scale, bias):
+def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
     """
     Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.
 
@@ -407,7 +407,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, bias):
         A (torch.Tensor): Left operand with shape (M, K), used in floating precision.
         qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling.
         Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B.
-        bias (torch.Tensor): Bias tensor with shape (M, N).
+        Bias (torch.Tensor): Bias tensor with shape (M, N).
 
     Returns:
         torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
@@ -417,12 +417,12 @@ def ref_program_twiddling_with_bias(A, qB, Scale, bias):
     for i in range(B.shape[0]):
         for j in range(B.shape[1]):
             B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
-    C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + bias
+    C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
     C = C.to(torch.__getattribute__(dtypeC))
     return C
 
 
-def ref_program_simple(A, qB, Scale, bias=None):
+def ref_program_simple(A, qB, Scale, Bias=None):
     """
     Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization.
 
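
The reference programs above apply a power-of-two scale per group of 32 columns with an element-wise Python loop. A hedged, vectorized equivalent of that loop (assuming `B.shape[1] == 32 * Scale.shape[1]`; this helper is not part of the file) can be handy when checking larger shapes:

```python
import torch

def dequant_scale_vectorized(B: torch.Tensor, Scale: torch.Tensor) -> torch.Tensor:
    """Vectorized form of `B[i][j] = B[i][j] * (2 ** Scale[i][j // 32])`."""
    exponents = Scale.repeat_interleave(32, dim=1).to(torch.float)  # spread each scale over its 32 columns
    return B.to(torch.float) * torch.exp2(exponents)
```

With the dequantized `B` in hand, the `*_with_bias` references reduce to `(A.float() @ B.T + Bias).to(torch.bfloat16)`.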
@@ -448,7 +448,7 @@ def ref_program_simple(A, qB, Scale, bias=None):
     return C
 
 
-def ref_program_simple_with_bias(A, qB, Scale, bias):
+def ref_program_simple_with_bias(A, qB, Scale, Bias):
     """
     Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization.
 
@@ -460,7 +460,7 @@ def ref_program_simple_with_bias(A, qB, Scale, bias):
     - A: 2D tensor representing the left operand (will be cast to float32 for the matmul).
     - qB: Quantized representation of B accepted by `torch_convert`.
     - Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32.
-    - bias: 2D tensor representing the bias (will be cast to float32 for the matmul).
+    - Bias: 2D tensor representing the Bias (will be cast to float32 for the matmul).
 
 
     Returns:
@@ -473,7 +473,7 @@ def ref_program_simple_with_bias(A, qB, Scale, bias):
     for i in range(B.shape[0]):
         for j in range(B.shape[1]):
             B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
-    C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + bias
+    C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
     C = C.to(torch.__getattribute__(dtypeC))
     return C
 
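
A quick sanity check one could run against these references (not part of this commit): with an all-zero Bias, each `*_with_bias` variant should match its no-bias counterpart exactly, since the bias is added before the final bfloat16 cast. The input construction below assumes bfloat16 activations, 4-bit packing (two values per uint8 byte), and one exponent per 32 columns.

```python
import torch

M, N, K = 128, 128, 256
A = torch.randn(M, K, dtype=torch.bfloat16)                    # assumed input dtype
qB = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8)     # assumed 4-bit packing
Scale = torch.randint(0, 4, (N, K // 32), dtype=torch.uint8)   # assumed exponent layout
Bias = torch.zeros(M, N, dtype=torch.bfloat16)

torch.testing.assert_close(
    ref_program_simple_with_bias(A, qB, Scale, Bias),
    ref_program_simple(A, qB, Scale),
)
```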