[Enhancement] Add an MXFP4 grouped GEMM example for FusedMoE (#811)
* [Enhancement] Enhance dequantization examples and utilities
- Added a new example for grouped matrix multiplication with experts in `example_dequant_groupgemm_bf16_mxfp4_hopper.py` (a reference sketch follows this list).
- Improved dequantization logic in existing examples by replacing nested loops with vectorized operations for better performance.
- Updated `torch_convert_bit_twiddling` function in `utils.py` to utilize parallel processing, enhancing efficiency and clarity in the conversion process.
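For intuition, a minimal PyTorch reference for the grouped (per-expert) GEMM pattern could look like the sketch below; `ref_grouped_gemm` and `group_sizes` are hypothetical names for illustration, not code from the PR:

```python
import torch

def ref_grouped_gemm(a: torch.Tensor, w: torch.Tensor,
                     group_sizes: list[int]) -> torch.Tensor:
    """Grouped GEMM reference: rows of `a` are grouped contiguously by
    expert, and each row slice is multiplied by that expert's weight.

    a: (sum(group_sizes), K) activations, already routed per expert
    w: (num_experts, N, K) per-expert (already dequantized) weights
    """
    outs, row = [], 0
    for expert_id, rows in enumerate(group_sizes):
        outs.append(a[row:row + rows] @ w[expert_id].t())  # (rows, N)
        row += rows
    return torch.cat(outs, dim=0)
```

The example's fused kernel presumably performs these per-expert multiplies while dequantizing the MXFP4 weights on the fly; the sketch only pins down the reference semantics.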
Co-authored-by: Zhengju Tang <97930865+tzj-fxz@users.noreply.github.com>
* fix typos in docstrings
* remove redundant code
* [Format] Remove non-reproducible T.print debug output
* [BugFix] Correct dtype in the reference dequantize; use a larger test data distribution
* [Format]
* [Refactor] Clean up and optimize example_dequant_groupgemm_bf16_mxfp4_hopper.py and utils.py
- Removed unnecessary cache disabling and manual seed setting in the example.
- Simplified nested loops into parallelized operations for better readability and performance.
- Updated the assertion function in utils.py to print detailed error messages (a sketch follows this list).
- Adjusted tensor sizes in the examples.
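A sketch of what such a detailed assertion could look like (`assert_similar`, its signature, and the tolerance are assumptions, not the actual utils.py code):

```python
import torch

def assert_similar(actual: torch.Tensor, expected: torch.Tensor,
                   eps: float = 1e-2, name: str = "tensor") -> None:
    """Fail with a descriptive message instead of a bare AssertionError."""
    diff = (actual.float() - expected.float()).abs()
    # Clamp the denominator so near-zero entries fall back to absolute error.
    denom = expected.float().abs().clamp_min(1.0)
    max_rel = (diff / denom).max().item()
    assert max_rel <= eps, (
        f"{name}: max relative error {max_rel:.4e} exceeds {eps:.1e} "
        f"(max abs diff {diff.max().item():.4e})")
```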
* [Refactor] Update import path in example_dequant_gemm_fine_grained.py
- Changed the import statement for `_tir_packed_to_unsigned_convert` from `bitblas.quantization` to `tilelang.quantize` to reflect the new module structure.
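Concretely, the updated import line would read:

```python
from tilelang.quantize import _tir_packed_to_unsigned_convert
```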
* lint
* rename and add test
* lint
* [Feature] Enhance autotuning and configuration generation in example_dequant_groupedgemm_bf16_mxfp4_hopper.py
- Added a new function `get_configs()` to generate hyperparameter configurations for tuning (sketched after this list).
- Updated the `matmul` function to use autotuning with the new configurations.
- Improved kernel performance via vectorization and thread-block swizzling.
- Enhanced the main function to support the new autotuning inputs and updated its parameters for better performance.
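A hedged sketch of such a configuration generator (the parameter names and value ranges here are illustrative guesses, not the example's actual tuning space):

```python
import itertools

def get_configs():
    """Enumerate candidate kernel hyperparameters for the autotuner."""
    block_M = [64, 128]
    block_N = [64, 128, 256]
    block_K = [64, 128]
    num_stages = [2, 3]
    threads = [128, 256]
    return [
        dict(block_M=m, block_N=n, block_K=k, num_stages=s, threads=t)
        for m, n, k, s, t in itertools.product(
            block_M, block_N, block_K, num_stages, threads)
    ]
```

Each candidate dict would then be benchmarked by the autotuner wrapped around `matmul`, keeping the fastest configuration.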
* lint
* fix typo
* fix typo and lint
* make ci format check happy
* fix ci
---------
Co-authored-by: Zhengju Tang <97930865+tzj-fxz@users.noreply.github.com>
Co-authored-by: tzj-fxz <tzjfxz@gmail.com>
examples/dequantize_gemm/utils.py (+76 −32)
@@ -3,8 +3,6 @@

 def torch_convert_bit_twiddling(tensor):
     """
-    Convert a 2-D uint8 tensor into a bfloat16 tensor by decoding pairs of input bytes with a bit-twiddling scheme.
-
     This function expects `tensor` to be a 2-D torch.Tensor of dtype `torch.uint8`. Each output element is produced by combining two input bytes and extracting a bf16-like 16-bit pattern according to one of four positional bit layouts (pos 0..3). The result is scaled by 2**126 to adjust the exponent bias and returned as dtype `torch.bfloat16`.
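For intuition about the bias adjustment the docstring mentions: an e2m1 (MXFP4 element) code can be widened to bfloat16 by moving its sign bit to bit 15 and its exponent/mantissa bits to bits 6..8, then multiplying by 2**126 to convert the e2m1 exponent bias (1) to the bfloat16 bias (127). The sketch below decodes a simple low-nibble-first packing; it does not reproduce the PR's specific four-position byte-pair layout:

```python
import torch

def dequant_e2m1_to_bf16(packed: torch.Tensor) -> torch.Tensor:
    """Decode a 2-D uint8 tensor holding two e2m1 codes per byte into
    bfloat16. Assumes plain low-nibble-first packing, NOT the byte-pair
    layout used by torch_convert_bit_twiddling."""
    assert packed.dtype == torch.uint8 and packed.dim() == 2
    lo, hi = packed & 0x0F, packed >> 4
    codes = torch.stack([lo, hi], dim=-1).reshape(
        packed.shape[0], -1).to(torch.int16)
    # Move the e2m1 exponent+mantissa (bits 0..2) to bf16 bits 6..8 and
    # reinterpret the 16-bit pattern as bfloat16.
    magnitude = ((codes & 0x7) << 6).view(torch.bfloat16)
    # The sign is bit 3 of the code.
    sign = 1.0 - 2.0 * ((codes >> 3) & 1).to(torch.bfloat16)
    # 2**126 shifts the exponent bias from e2m1's 1 to bfloat16's 127.
    return magnitude * sign * (2.0 ** 126)
```

Note that full MXFP4 data also carries a shared power-of-two scale per 32-element block, which this element-level sketch ignores.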