- 
                Notifications
    You must be signed in to change notification settings 
- Fork 290
[MXFP4] Fix bugs and optimize exponential operation #750
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|  | @@ -40,8 +40,8 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| e_bf16 = e_f4 + tir.const(126, "uint16") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Scale is the exponential part, within the representation of uint8 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # To handle the overflow, we use the max function to limit the exponential part to 8 bits | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # To handle the overflow, we may use the min function to limit the exponential part to 8 bits | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| m_f4 = f4 & tir.const(1, "uint16") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| val_bf16 = tir.reinterpret("bfloat16", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | @@ -218,7 +218,7 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B_local_thread = T.alloc_local((local_compress_size,), storage_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Scale_local_thread = T.alloc_local((1,), storage_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Scale_local_thread_exponent = T.alloc_local((1,), "float32") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Scale_local_thread_exponent = T.alloc_local((1,), out_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i in T.serial(0, block_N * block_K // threads // local_size): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # First, load data from share memory to register. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | @@ -231,8 +231,7 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| si = index_scale // (block_K // scale_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sj = index_scale % (block_K // scale_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Scale_local_thread[0] = Scale[bx * block_N + si, k * block_K // scale_size + sj] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Scale_local_thread_exponent[0] = T.exp2( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.cast(Scale_local_thread[0] - 127, "float")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      +234
     to 
      235
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion T.shift_left(1, Scale) returns an integer, not BF16; negative exponents and overflow aren’t handled. 
 Robust fix: build the exact BF16 2^scale by constructing the BF16 bit-pattern (handles negative exponents and avoids integer overflow). Replace the assignment with: -                Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0]))
+                # Treat stored scale as signed int8 exponent offset from 0, then form a BF16 value 2^scale.
+                scale_i8 = tir.reinterpret("int8", Scale_local_thread[0])
+                e_i16 = T.cast(scale_i8, "int16") + tir.const(127, "int16")   # BF16 bias
+                # Clamp to avoid subnormals/inf; adjust limits as desired (0..254 keeps finite, non-subnormal range)
+                e_i16 = T.min(T.max(e_i16, tir.const(0, "int16")), tir.const(254, "int16"))
+                e_u16 = T.cast(e_i16, "uint16")
+                Scale_local_thread_exponent[0] = tir.reinterpret(
+                    out_dtype,
+                    T.shift_left(e_u16, tir.const(7, "uint16")),  # mantissa = 0
+                )If you truly want the shift approach for performance and can guarantee non-negative, small exponents, at minimum add explicit casts to satisfy the type-checker and document the precondition: -                Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0]))
+                # Precondition: Scale in [0, 30]
+                Scale_local_thread_exponent[0] = T.cast(
+                    T.shift_left(tir.const(1, "int32"), T.cast(Scale_local_thread[0], "int32")),
+                    out_dtype,
+                )📝 Committable suggestion
 
        Suggested change
       
 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Then, dequant. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.call_extern( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | @@ -288,7 +287,7 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - Mutates B_dequantize_shared by storing the dequantized BF16 fragment. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B_local = T.alloc_fragment(B_shared_shape, storage_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bx = T.get_block_binding(0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.copy(B_shared, B_local) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | @@ -300,8 +299,9 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Scale[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bx * block_N + i, k * block_K // scale_size + j // | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scale_size], # Scale is the exponential part, within the representation of uint8 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype=in_dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype=out_dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) * T.shift_left( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      299
     to 
      +304
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The expression to access the scale value is duplicated. To improve readability and avoid redundant calculations, you can store the scale value in a local variable before using it. For example: scale_val = Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
    num_bits,
    B_local[i, j // num_elems_per_byte],
    j % num_elems_per_byte,
    scale_val,  # This argument is unused and can be removed if the function is refactored.
    dtype=out_dtype,
) * T.shift_left(1, scale_val) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.copy(B_dequantize_local, B_dequantize_shared) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 
      Comment on lines
    
      300
     to 
      305
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Simple path repeats the same issues: int shift typedness, negative exponents, and overflow. 
 Apply the same bit-pattern approach inline to form a BF16  -                ) * T.shift_left(
-                    1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
+                ) * tir.reinterpret(
+                    out_dtype,
+                    T.shift_left(
+                        T.cast(
+                            T.min(
+                                T.max(
+                                    T.cast(
+                                        tir.reinterpret(
+                                            "int8",
+                                            Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size],
+                                        ),
+                                        "int16",
+                                    ) + tir.const(127, "int16"),
+                                    tir.const(0, "int16"),
+                                ),
+                                tir.const(254, "int16"),
+                            ),
+                            "uint16",
+                        ),
+                        tir.const(7, "uint16"),
+                    ),
+                )If you keep the shift optimization with restricted Scale ranges, use explicit casts and add an assert: +                T.Assert(
+                    (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size] <= tir.const(30, "uint8")),
+                    "Scale must be <= 30 when using integer shift.",
+                )
-                ) * T.shift_left(
-                    1, (Scale[...]))
+                ) * T.cast(
+                    T.shift_left(
+                        tir.const(1, "int32"),
+                        T.cast(Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size], "int32"),
+                    ),
+                    out_dtype,
+                )📝 Committable suggestion
 
        Suggested change
       
 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return simple_dequant_bf16_fp4 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | @@ -374,7 +374,7 @@ def ref_program_twiddling(A, qB, Scale): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B = torch_convert_bit_twiddling(qB) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i in range(B.shape[0]): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for j in range(B.shape[1]): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B[i][j] = B[i][j] * (2**(Scale[i][j // 32] - 127)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| C = C.to(torch.__getattribute__(dtypeC)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return C | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | @@ -400,7 +400,7 @@ def ref_program_simple(A, qB, Scale): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B = torch_convert(qB) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i in range(B.shape[0]): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for j in range(B.shape[1]): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B[i][j] = B[i][j] * (2**(Scale[i][j // 32] - 127)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| C = C.to(torch.__getattribute__(dtypeC)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return C | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | @@ -427,7 +427,15 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if tune: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel = matmul( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| m, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| n, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "bfloat16", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "bfloat16", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "float32", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_bits=4, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scale_size=scale_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fast_dequant=fast_dequant) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel = matmul( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| m, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | @@ -443,7 +451,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_K=128, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_stages=2, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| threads=256, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| split=1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| split=1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fast_dequant=fast_dequant) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the
scaleis now applied externally, this commented-out code is dead and can be removed to improve clarity. Thescaleparameter in the function signature (line 10) also becomes unused and could be removed in a follow-up refactoring.