Skip to content

Commit 4eeaf53

Browse files
committed
fix ci
1 parent 1b85cb6 commit 4eeaf53

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None):
389389
"""
390390
dtypeC = "bfloat16"
391391
B = torch_convert_bit_twiddling(qB)
392-
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // scale_size)])
392+
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
393393
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
394394
C = C.to(torch.__getattribute__(dtypeC))
395395
return C
@@ -412,7 +412,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
412412
"""
413413
dtypeC = "bfloat16"
414414
B = torch_convert_bit_twiddling(qB)
415-
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // scale_size)])
415+
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
416416
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
417417
C = C.to(torch.__getattribute__(dtypeC))
418418
return C
@@ -436,7 +436,7 @@ def ref_program_simple(A, qB, Scale, Bias=None):
436436
"""
437437
dtypeC = "bfloat16"
438438
B = torch_convert(qB)
439-
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // scale_size)])
439+
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
440440
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
441441
C = C.to(torch.__getattribute__(dtypeC))
442442
return C
@@ -464,7 +464,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias):
464464
"""
465465
dtypeC = "bfloat16"
466466
B = torch_convert(qB)
467-
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // scale_size)])
467+
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
468468
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
469469
C = C.to(torch.__getattribute__(dtypeC))
470470
return C

0 commit comments

Comments
 (0)