Skip to content

Commit be2f79a

Browse files
committed
lint
1 parent a3ca871 commit be2f79a

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@
1111
def get_configs():
1212
"""
1313
Generate a list of hyperparameter configuration dictionaries for tuning.
14-
14+
1515
Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K',
1616
'num_stages', 'threads', and 'split'. The function returns the Cartesian
1717
product of the parameter value lists:
1818
- block_M, block_N, block_K: tiling sizes
1919
- num_stages: pipeline stages
2020
- threads: thread counts
2121
- split: K-splitting factor
22-
22+
2323
Returns:
2424
List[dict]: A list of configuration dictionaries covering all combinations.
2525
"""
@@ -309,17 +309,20 @@ def main(
309309
C_local[i, j] = Bias_shared[j]
310310

311311
tx = T.get_thread_binding()
312-
312+
313313
for k in T.Pipelined(K // block_K, num_stages=num_stages):
314314
for copy_i in T.serial(block_M * block_K // threads // 16):
315315
base = copy_i * threads * 16 + tx * 16
316316
if sorted_token_ids_shared[base // block_K] != -1:
317317
for copy_j in T.vectorized(16):
318-
A_shared[base // block_K, base % block_K + copy_j] = A[sorted_token_ids_shared[base // block_K] // topk, k * block_K + base % block_K + copy_j]
318+
A_shared[base // block_K, base % block_K +
319+
copy_j] = A[sorted_token_ids_shared[base // block_K] // topk,
320+
k * block_K + base % block_K + copy_j]
319321

320322
T.copy(B[expert_id[0], bx * block_N, k * block_K // num_elems_per_byte], B_shared)
321323
if fast_dequant:
322-
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k)
324+
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared,
325+
k)
323326
else:
324327
get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k)
325328

@@ -331,7 +334,7 @@ def main(
331334
T.copy(C_local, C_shared)
332335
for i, j in T.Parallel(block_M, block_N):
333336
C[sorted_token_ids_shared[i] // topk, sorted_token_ids_shared[i] % topk,
334-
bx * block_N + j] = C_shared[i, j]
337+
bx * block_N + j] = C_shared[i, j]
335338

336339
return main
337340

@@ -366,7 +369,8 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc
366369

367370
# Compute the output for this token-expert pair
368371
# token_embedding @ B.T + bias
369-
output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(torch.bfloat16)) + Bias[expert_id]
372+
output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(
373+
torch.bfloat16)) + Bias[expert_id]
370374
output = output.to(torch.__getattribute__(dtypeC))
371375

372376
# Apply the topk weight
@@ -491,7 +495,9 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
491495
max_val = diff.max()
492496
max_idx = diff.argmax()
493497
print(f"max abs diff: {max_val} at index: {max_idx}")
494-
assert_similar(output, ref_output, name="output", eps=1e-5) # We care about the similarity rather than abs. difference
498+
assert_similar(
499+
output, ref_output, name="output",
500+
eps=1e-5) # We care about the similarity rather than abs. difference
495501
print("All checks pass. ✅")
496502

497503

examples/dequantize_gemm/test_example_dequantize_gemm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def test_example_dequant_gemm_bf16_mxfp4_hopper_tma():
3535
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
3636
def test_example_dequant_groupedgemm_bf16_mxfp4_hopper():
3737
example_dequant_groupedgemm_bf16_mxfp4_hopper.main()
38-
39-
38+
39+
4040
@tilelang.testing.requires_cuda
4141
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
4242
def test_example_dequant_gemm_w4a8():

0 commit comments

Comments (0)