
Commit a7fe6ee

[Lint] Rename "bias" to "Bias"
1 parent c07d8f8 commit a7fe6ee

File tree

1 file changed (+17, -17 lines)


examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py

Lines changed: 17 additions & 17 deletions
@@ -121,7 +121,7 @@ def matmul(M,
         num_stages (int, optional): pipelining stages for K loop (default 2).
         threads (int, optional): threads per block used by the kernel (default 256).
         split (int, optional): split factor along K used by the scheduler (default 1).
-        with_bias (bool, optional): whether to add bias to the output (default False).
+        with_bias (bool, optional): whether to add Bias to the output (default False).

     Returns:
         A T.prim_func implementing the tiled, pipelined GEMM that:
@@ -141,11 +141,11 @@ def matmul(M,
     Block_QK = block_K // num_elems_per_byte
     A_shape = (M, K)
     B_shape = (N, QK)
-    bias_shape = (M, N)
+    Bias_shape = (M, N)
     Scale_shape = (N, K // scale_size)
     A_shared_shape = (block_M, block_K)
     B_shared_shape = (block_N, Block_QK)
-    bias_shared_shape = (block_M, block_N)
+    Bias_shared_shape = (block_M, block_N)
     B_dequantize_shared_shape = (block_N, block_K)
     assert K % (block_K * split) == 0

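As a quick orientation for the shapes above, a small sketch of the arithmetic (the tile sizes here are hypothetical, chosen only for illustration; with 4-bit weights two values pack into one byte, so the packed B tile holds block_K // 2 bytes per row):

# Hypothetical tile sizes, only to illustrate the shape arithmetic above.
block_M, block_N, block_K = 128, 128, 64
num_elems_per_byte = 2                      # two 4-bit values per packed byte
Block_QK = block_K // num_elems_per_byte    # 32 packed bytes of B per K-tile
B_shared_shape = (block_N, Block_QK)        # (128, 32)
Bias_shared_shape = (block_M, block_N)      # (128, 128), the Bias tile added to C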

@@ -315,7 +315,7 @@ def main(
             A: T.Tensor(A_shape, in_dtype),
             B: T.Tensor(B_shape, storage_dtype),
             Scale: T.Tensor(Scale_shape, storage_dtype),
-            bias: T.Tensor(bias_shape, out_dtype),
+            Bias: T.Tensor(Bias_shape, out_dtype),
             C: T.Tensor((M, N), out_dtype),
     ):
         """
@@ -333,7 +333,7 @@ def main(
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)
             B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
             B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
-            bias_shared = T.alloc_shared(bias_shared_shape, out_dtype)
+            Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
             C_shared = T.alloc_shared((block_M, block_N), out_dtype)


@@ -345,16 +345,16 @@ def main(

             if with_bias:
                 T.annotate_layout({
-                    bias_shared: tilelang.layout.make_swizzled_layout(bias_shared),
+                    Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared),
                 })

             if threads == 512:
                 T.disable_warp_group_reg_alloc()

             if with_bias:
-                T.copy(bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N],
-                       bias_shared)
-                T.copy(bias_shared, C_local)
+                T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N],
+                       Bias_shared)
+                T.copy(Bias_shared, C_local)
             else:
                 T.clear(C_local)

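In other words, when with_bias is set the accumulator tile C_local is seeded with the matching Bias tile before the K loop, so the GEMM accumulates on top of it; otherwise the tile starts from zero. A minimal PyTorch sketch of that semantics, standing in for the kernel rather than reproducing it, with B already dequantized:

import torch

def gemm_with_optional_bias(A, B, Bias=None):
    # Mirrors the branch above: seed the accumulator from Bias (T.copy) or
    # from zeros (T.clear), then accumulate A @ B^T in float and cast back.
    M, N = A.shape[0], B.shape[0]
    acc = Bias.to(torch.float) if Bias is not None else torch.zeros(M, N)
    C = acc + torch.matmul(A.to(torch.float), B.T.to(torch.float))
    return C.to(torch.bfloat16)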

@@ -373,7 +373,7 @@ def main(
     return main


-def ref_program_twiddling(A, qB, Scale, bias=None):
+def ref_program_twiddling(A, qB, Scale, Bias=None):
     """
     Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.

@@ -397,7 +397,7 @@ def ref_program_twiddling(A, qB, Scale, bias=None):
     return C


-def ref_program_twiddling_with_bias(A, qB, Scale, bias):
+def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
     """
     Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.

@@ -407,7 +407,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, bias):
         A (torch.Tensor): Left operand with shape (M, K), used in floating precision.
         qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling.
         Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B.
-        bias (torch.Tensor): Bias tensor with shape (M, N).
+        Bias (torch.Tensor): Bias tensor with shape (M, N).

     Returns:
         torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
@@ -417,12 +417,12 @@ def ref_program_twiddling_with_bias(A, qB, Scale, bias):
     for i in range(B.shape[0]):
         for j in range(B.shape[1]):
             B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
-    C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + bias
+    C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
     C = C.to(torch.__getattribute__(dtypeC))
     return C


-def ref_program_simple(A, qB, Scale, bias=None):
+def ref_program_simple(A, qB, Scale, Bias=None):
     """
     Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization.

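The nested loop above applies a power-of-two scale to each group of 32 columns of the dequantized B. An equivalent vectorized form, sketched under the assumption (implied by the loop) that Scale stores integer exponents and B has a multiple of 32 columns:

import torch

def apply_group_scales(B, Scale):
    # Same effect as the loop: B[i][j] *= 2 ** Scale[i][j // 32].
    # Scale has shape (N, K // 32); broadcast each exponent over its 32 columns.
    factors = torch.pow(2.0, Scale.to(torch.float)).repeat_interleave(32, dim=1)
    return B.to(torch.float) * factors
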
@@ -448,7 +448,7 @@ def ref_program_simple(A, qB, Scale, bias=None):
     return C


-def ref_program_simple_with_bias(A, qB, Scale, bias):
+def ref_program_simple_with_bias(A, qB, Scale, Bias):
     """
     Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization.

@@ -460,7 +460,7 @@ def ref_program_simple_with_bias(A, qB, Scale, bias):
     - A: 2D tensor representing the left operand (will be cast to float32 for the matmul).
     - qB: Quantized representation of B accepted by `torch_convert`.
     - Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32.
-    - bias: 2D tensor representing the bias (will be cast to float32 for the matmul).
+    - Bias: 2D tensor representing the Bias (will be cast to float32 for the matmul).


     Returns:
@@ -473,7 +473,7 @@ def ref_program_simple_with_bias(A, qB, Scale, bias):
     for i in range(B.shape[0]):
         for j in range(B.shape[1]):
             B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
-    C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + bias
+    C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
     C = C.to(torch.__getattribute__(dtypeC))
     return C

