
Commit 5a1a9df

📝 Add docstrings to mxfp4
Docstrings generation was requested by @LeiWang1999.

* #725 (comment)

The following files were modified:
* `examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py`
* `examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py`
* `examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py`
* `examples/dequantize_gemm/utils.py`
* `examples/gemm/example_gemm_autotune.py`
* `tilelang/intrinsics/utils.py`
* `tilelang/language/__init__.py`
* `tilelang/language/utils.py`
* `tilelang/quantize/mxfp.py`
* `tilelang/quantize/quantization.py`
1 parent 24603e4 commit 5a1a9df

File tree

10 files changed: +628, -41 lines changed

examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py

Lines changed: 51 additions & 1 deletion
@@ -82,6 +82,39 @@ def bitnet_158_int8xint2_prefill(
         warp_col_tiles=32,
         chunk=64,
 ):
+    """
+    Create a TVM GPU prim_func implementing a block-tiled matrix multiply that multiplies dense A by compressed/interleaved low-precision B (2-bit values packed into int8 storage), decoding B to int8 on-chip and accumulating into C.
+
+    The returned prim_func expects:
+    - A: shape (M, K) with dtype `in_dtype` ("float16" or "int8").
+    - B: compressed storage of shape (N, K/4) with an int8 storage layout (4 2-bit elements packed per byte).
+    - C: output buffer of shape (M, N) with dtype `out_dtype` ("float16", "float32", or "int32").
+
+    Details:
+    - Builds a tiled, pipelined kernel using shared memory and warp-level MMA intrinsics (INT4TensorCoreIntrinEmitter). B is loaded from compressed storage, decoded to int8 in threads (via decode_i2u_to_i8s / decode_i2s_to_i8s), and dequantized into a shared buffer used by the MMA emitter.
+    - Tiling parameters:
+      - block_row_warps, block_col_warps: number of warps per block in the row/column dimensions.
+      - warp_row_tiles, warp_col_tiles: tiles per warp.
+      - chunk: K-sized chunk per block (block_K).
+      - Micro sizes are fixed (16x16x16, except micro_k = 32 when accum_dtype == "int32").
+    - Uses 2-stage pipelining by default to overlap loads and compute, and applies a swizzle layout to improve L2 behavior.
+    - Raises AssertionError if in_dtype or out_dtype is not among the supported values.
+
+    Parameters:
+        M, N, K (int): Global matrix dimensions.
+        in_dtype (str): Input and decoded-B element dtype; "float16" or "int8".
+        out_dtype (str): Output C dtype; one of "float16", "float32", "int32".
+        accum_dtype (str): Accumulator dtype used by MMA (e.g., "int32").
+        fast_decoding (bool): If True, enable the fast decoding path (selects which device decode routine is used).
+        block_row_warps (int): Warps per block in the row dimension.
+        block_col_warps (int): Warps per block in the column dimension.
+        warp_row_tiles (int): Tiles per warp in the row dimension.
+        warp_col_tiles (int): Tiles per warp in the column dimension.
+        chunk (int): K length per block (block_K).
+
+    Returns:
+        T.prim_func: A TVM prim_func implementing the described GPU kernel, suitable for compilation and execution.
+    """
     assert in_dtype in [
         "float16",
         "int8",
@@ -152,7 +185,24 @@ def main(
             B: T.Buffer(B_shape, storage_dtype),
             C: T.Buffer((M, N), out_dtype),
     ):
-        with T.Kernel(
+        """
+        GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T, writing into C.
+
+        This kernel:
+        - Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory.
+        - Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine.
+        - Uses warp/MMA tiled fragments and an INT4/INT2-capable MMA emitter to accumulate across K in a pipelined fashion with a configurable number of stages.
+        - Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing.
+
+        Parameters:
+            A: Input matrix buffer of shape A_shape and element type `in_dtype`; holds the MxK activations.
+            B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel.
+            C: Output buffer of shape (M, N) and type `out_dtype`; receives the result (accumulated values are produced in `accum_dtype` and stored into C).
+
+        Side effects:
+            Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation.
+        """
+        with T.Kernel(
                 T.ceildiv(N, block_N),
                 T.ceildiv(M, block_M),
                 threads=threads,
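For readers unfamiliar with the packing, here is a host-side sketch of the 2-bit expansion this kernel performs on-chip. It is a minimal illustration in plain PyTorch assuming a simple LSB-first layout; the kernel itself uses the decode_i2u_to_i8s / decode_i2s_to_i8s device intrinsics, and its real storage is interleaved.

    import torch

    def unpack_i2_to_i8(packed: torch.Tensor) -> torch.Tensor:
        """Expand (N, K//4) uint8 storage into (N, K) int8, four 2-bit values per byte."""
        n, k4 = packed.shape
        out = torch.empty(n, k4 * 4, dtype=torch.int8)
        for pos in range(4):
            # value `pos` of byte j lands at column 4*j + pos (LSB-first assumption)
            out[:, pos::4] = ((packed >> (2 * pos)) & 0x3).to(torch.int8)
        return out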

examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py

Lines changed: 197 additions & 3 deletions
Large diffs are not rendered by default.

examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py

Lines changed: 195 additions & 3 deletions
Large diffs are not rendered by default.

examples/dequantize_gemm/utils.py

Lines changed: 36 additions & 0 deletions
@@ -3,6 +3,20 @@


 def torch_convert_bit_twiddling(tensor):
+    """
+    Convert a 2-D uint8 tensor into a bfloat16 tensor by decoding pairs of input bytes with a bit-twiddling scheme.
+
+    This function expects `tensor` to be a 2-D torch.Tensor of dtype `torch.uint8`. Each output element is produced by combining two input bytes and extracting a bf16-like 16-bit pattern according to one of four positional bit layouts (pos 0..3). The result is scaled by 2**126 to adjust the exponent bias and returned as dtype `torch.bfloat16`.
+
+    Parameters:
+        tensor (torch.Tensor): 2-D input tensor of dtype `torch.uint8` and shape (N, K).
+
+    Returns:
+        torch.Tensor: New tensor of dtype `torch.bfloat16` with shape (N, K*2), where each input column pair produces two bf16 output columns.
+
+    Raises:
+        AssertionError: If any byte inputs used for a conversion are not of dtype `torch.uint8`.
+    """
     def _convert(val0, val1, pos) -> torch.bfloat16:
         assert val0.dtype == torch.uint8
         assert val1.dtype == torch.uint8
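The 2**126 factor mentioned in the docstring is an exponent-rebias trick: if the extracted 16-bit pattern carries the raw fp4 exponent (bias 1) in the bf16 exponent field, a single multiply converts it to the bf16 bias of 127. A tiny illustration of the idea (the exact twiddled layout is an assumption here, not taken from the source):

    import torch

    # fp4 0b0011 is 1.5 (exponent field 1, mantissa bit set); dropped raw into
    # bf16 bits it reads as 1.5 * 2**-126 until the rebias multiply corrects it.
    bits = torch.tensor([0x00C0], dtype=torch.int16)
    val = bits.view(torch.bfloat16)        # 1.5 * 2**-126
    print(val * (2.0 ** 126))              # tensor([1.5000], dtype=torch.bfloat16)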
@@ -38,6 +52,19 @@ def _convert(val0, val1, pos) -> torch.bfloat16:


 def torch_convert(tensor, scale_size=None, Scale=None):
+    """
+    Decode a 2-D uint8 tensor into a 2-D bfloat16 tensor by expanding each byte into two bf16 values using a 4-bit (nibble) encoding.
+
+    Each input byte holds two 4-bit encoded values (low and high nibbles). For each nibble this function derives sign/scale bits, a 3-bit exponent fragment, and a 1-bit mantissa fragment, assembles a 16-bit bf16 pattern, and returns the resulting tensor with shape (N, K*2) and dtype torch.bfloat16 on the same device as the input.
+
+    Parameters:
+        tensor (torch.Tensor): 2-D tensor of dtype torch.uint8 and shape (N, K). Each byte contains two encoded 4-bit entries that become two bf16 values.
+        scale_size (int, optional): If provided, controls how elements of the optional Scale tensor are indexed; per-output-element scaling is then applied to the exponent using Scale.
+        Scale (torch.Tensor, optional): 2-D tensor of per-element integer scale adjustments to the exponent. When scale_size is provided, the scale used for output element (i, j) is Scale[i][j // scale_size].
+
+    Returns:
+        torch.Tensor: A new tensor of shape (N, K*2) and dtype torch.bfloat16 containing the decoded bf16 values.
+    """
     def _convert(val, pos, scale=None):
         assert val.dtype == torch.uint8
         # val = val.view(torch.int8)
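A hedged sketch of decoding a single nibble, following the bit layout documented for _tir_u8_to_f4_to_bf16 in tilelang/quantize/quantization.py below (sign = bit 3, exponent = bits 1-2, mantissa = bit 0, bias 126); the library routine is vectorized and additionally folds the optional per-group Scale into the exponent:

    import torch

    def decode_nibble_to_bf16(nibble: int) -> torch.Tensor:
        sign = (nibble >> 3) & 0x1
        exp = ((nibble >> 1) & 0x3) + 126     # rebias the 2-bit exponent for bf16
        man = nibble & 0x1
        bits = (sign << 15) | (exp << 7) | (man << 6)
        if bits >= 0x8000:                    # wrap to int16 two's complement
            bits -= 0x10000
        return torch.tensor([bits], dtype=torch.int16).view(torch.bfloat16)

    print(decode_nibble_to_bf16(0b0011))      # tensor([1.5000], dtype=torch.bfloat16)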
@@ -67,6 +94,15 @@ def _convert(val, pos, scale=None):


 def print_bit(name, val):
+    """
+    Print the 32-bit binary representation of a CPU scalar extracted from a PyTorch tensor.
+
+    Moves `val` to the CPU, reads its Python scalar with `.item()`, formats it as a 32-bit binary string, and prints it prefixed by `name`.
+
+    Parameters:
+        name (str): Label printed before the binary representation.
+        val (torch.Tensor): Scalar PyTorch tensor (numeric) whose 32-bit binary representation will be shown.
+    """
     val_cpu = val.cpu().item()
     binary_repr = f'{val_cpu:032b}'
     print(name, binary_repr)
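Usage is straightforward; for example, inspecting the bit pattern of 1.0f (reinterpreted as int32 so `.item()` yields an integer for the binary format):

    import torch

    one_bits = torch.tensor(1.0).view(torch.int32)   # reinterpret float32 bits
    print_bit("one:", one_bits)                      # one: 00111111100000000000000000000000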

examples/gemm/example_gemm_autotune.py

Lines changed: 30 additions & 0 deletions
@@ -11,10 +11,40 @@


 def ref_program(A, B):
+    """
+    Compute the matrix product of A and the transpose of B.
+
+    A and B are expected to be 2-D tensors where A has shape (M, K) and B has shape (N, K). The result is a tensor of shape (M, N) equal to A @ B.T, using the inputs' dtypes.
+    """
     return A @ B.T


 def get_configs(M, N, K, with_roller=False, topk=20):
+    """
+    Generate a list of kernel tuning configuration dictionaries for a tiled matrix multiply.
+
+    When with_roller is True, this queries the MatmulTemplate roller to produce up to `topk` recommended configurations (device-specific, TensorCore-friendly tilings). Each returned dict contains:
+    - block_M, block_N, block_K: tile sizes
+    - num_stages: pipeline staging (0 means no explicit staging)
+    - thread_num: total threads used for the block
+    - enable_rasteration: whether a rasterization/swizzle layout was recommended (note the spelling)
+
+    When with_roller is False, this returns the Cartesian product of a fixed set of candidate parameters; the returned dicts use the backward-compatible key name "enable_rasteration" for that flag.
+
+    Parameters:
+        M, N, K (int): GEMM dimensions used to generate valid tile sizes.
+        with_roller (bool): If True, use MatmulTemplate's roller to generate device-aware hints; otherwise use a predefined candidate grid.
+        topk (int): Maximum number of roller hints to request when with_roller is True.
+
+    Returns:
+        List[dict]: A list of configuration dictionaries as described above.
+
+    Raises:
+        ValueError: If with_roller is True but the roller returns no hints.
+    """
     if with_roller:
         arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
         carve_template = MatmulTemplate(
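For the with_roller=False path, the grid construction amounts to an itertools.product over candidate values; a minimal sketch (the candidate lists below are illustrative assumptions, not the example's exact values):

    import itertools

    def get_configs_sketch():
        keys = ("block_M", "block_N", "block_K", "num_stages",
                "thread_num", "enable_rasteration")
        candidates = (
            [64, 128, 256],   # block_M
            [64, 128, 256],   # block_N
            [32, 64],         # block_K
            [0, 2, 3],        # num_stages (0 = no explicit staging)
            [128, 256],       # thread_num
            [True, False],    # enable_rasteration (backward-compatible spelling)
        )
        return [dict(zip(keys, combo)) for combo in itertools.product(*candidates)]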

tilelang/intrinsics/utils.py

Lines changed: 13 additions & 0 deletions
@@ -76,6 +76,19 @@ def mfma_store_index_map(thread_id, local_id):
 def get_mma_micro_size(dtype: Literal["float16", "int8"]):
     # TODO(lei): FP8 related precision support.
     # Basic Tensor Core Matrix Multiply operation Unit
+    """
+    Return the MMA (Tensor Core) micro-tile dimensions for a given data type.
+
+    This function returns the micro-tile sizes (x, y, k) used by MMA/Tensor Core operations:
+    - x: tile width in the output/result dimension
+    - y: tile height in the output/result dimension
+    - k: tile depth in the reduction (K) dimension
+
+    Accepted dtype strings include "float16", "int8", and some FP8 identifiers ("float8_e4m3", "float8_e5m2"). For FP8 and int8 types the reduction depth (`k`) is 32; for float16 it is 16.
+
+    Returns:
+        tuple[int, int, int]: (micro_size_x, micro_size_y, micro_size_k)
+    """
     micro_size_x = micro_size_y = 16
     micro_size_k = 16
     if dtype in {"float8_e4m3", "float8_e5m2", "int8"}:

tilelang/language/__init__.py

Lines changed: 10 additions & 1 deletion
@@ -17,7 +17,6 @@
     make_tensor,  # noqa: F401
     Buffer,  # noqa: F401
     Tensor,  # noqa: F401
-    StridedTensor,  # noqa: F401
     FragmentBuffer,  # noqa: F401
     SharedBuffer,  # noqa: F401
     LocalBuffer,  # noqa: F401
@@ -73,6 +72,16 @@


 def symbolic(name: str, dtype: str = "int32"):
+    """
+    Create a TIR symbolic variable.
+
+    Parameters:
+        name (str): Identifier for the variable in the generated TIR.
+        dtype (str): Data type string for the variable (e.g., "int32"). Defaults to "int32".
+
+    Returns:
+        tir.Var: A TIR variable with the given name and dtype, for use in TIR/TensorIR kernels.
+    """
     return tir.Var(name, dtype)
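A minimal usage sketch (assuming tilelang.language is imported as T, the convention used in the examples):

    import tilelang.language as T

    m = T.symbolic("m")           # int32 by default
    k = T.symbolic("k", "int64")  # explicit dtype
    # m and k are tir.Var instances usable as dynamic shape dimensions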

tilelang/language/utils.py

Lines changed: 32 additions & 29 deletions
@@ -4,24 +4,16 @@


 def index_to_coordinates(index, shape) -> list[PrimExpr]:
     """
-    Convert a flat (linear) index to multi-dimensional coordinates for a given shape.
-
-    Example:
-        shape = (4, 5, 6)
-        index = 53
-        index_to_coordinates(53, (4, 5, 6)) -> [1, 3, 5]
-        # Explanation:
-        # 53 // (5*6) = 1  (1st coordinate)
-        # 53 % (5*6) = 23
-        # 23 // 6 = 3      (2nd coordinate)
-        # 23 % 6 = 5       (3rd coordinate)
-
-    Args:
-        index (int): The flat index to convert.
-        shape (tuple or list of int): The shape of the multi-dimensional array.
-
+    Convert a flat (linear) index into multi-dimensional coordinates for a given shape.
+
+    Given a linear index and a shape (sequence of dimension extents), returns a list of coordinates (one per dimension) such that converting those coordinates back to a linear index with the usual row-major / C-order formula yields the original index. The computation iterates from the last dimension to the first using modulo and integer division, then reverses the collected coordinates.
+
+    Parameters:
+        index (int or PrimExpr): The flat index to convert.
+        shape (Sequence[int]): The extents of each dimension (length >= 1).
+
     Returns:
-        list: A list of coordinates corresponding to each dimension.
+        list[PrimExpr]: Coordinates for each dimension, in the same order as `shape`.
     """
     coordinates = []
     dims = len(shape)
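A plain-Python sketch of the row-major conversion for concrete integers (the library version operates on TIR PrimExprs):

    def index_to_coordinates_sketch(index, shape):
        coords = []
        for dim in reversed(shape):
            index, r = divmod(index, dim)  # peel off the fastest-varying dimension
            coords.append(r)
        return list(reversed(coords))

    assert index_to_coordinates_sketch(53, (4, 5, 6)) == [1, 3, 5]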
@@ -34,18 +26,29 @@ def index_to_coordinates(index, shape) -> list[PrimExpr]:

 def linear_index(*args: PrimExpr) -> PrimExpr:
     """
-    Convert a list of coordinates to a flat (linear) index using strides.
-
-    Usage examples:
-        linear_index(i) -> i
-        linear_index(i, j) -> i * stride + j
-        linear_index(i, j, stride_j) -> i * stride_j + j
-        linear_index(i, j, k, stride_j, stride_k)
-            -> i * stride_j * stride_k + j * stride_k + k
-
-    Example for index = i * threads * local_size + tx * local_size + v:
-        Suppose you have i, tx, v as coordinates, and threads, local_size as strides:
-        linear_index(i, tx, v, threads, local_size) == i * threads * local_size + tx * local_size + v
+    Compute a flat (linear) index from multi-dimensional coordinates and strides.
+
+    The function accepts a sequence of PrimExpr arguments whose leading portion are coordinates and whose trailing portion are the corresponding strides. The number of strides must equal the number of coordinates minus one. The linear index is computed as:
+
+        linear = coords[0]
+        for each (coord, stride) in zip(coords[1:], strides):
+            linear = linear * stride + coord
+
+    Examples:
+        - linear_index(i) -> i
+        - linear_index(i, j, stride_j) -> i * stride_j + j
+        - linear_index(i, j, k, stride_j, stride_k) -> i * stride_j * stride_k + j * stride_k + k
+        - linear_index(i, tx, v, threads, local_size) -> i * threads * local_size + tx * local_size + v
+
+    Raises:
+        ValueError: If called with no arguments, or if the number of strides is not one less than the number of coordinates.
+
+    Returns:
+        PrimExpr: The computed linear index expression.
     """
     n = len(args)
     if n == 0:
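The stride-folding formula is easy to check with concrete integers. This sketch separates coordinates and strides into two sequences for clarity, whereas the library packs both groups into a single *args list:

    def linear_index_sketch(coords, strides):
        # strides must have exactly len(coords) - 1 entries
        linear = coords[0]
        for coord, stride in zip(coords[1:], strides):
            linear = linear * stride + coord
        return linear

    # i * threads * local_size + tx * local_size + v
    assert linear_index_sketch([2, 3, 4], [8, 16]) == 2 * 8 * 16 + 3 * 16 + 4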

tilelang/quantize/mxfp.py

Lines changed: 23 additions & 3 deletions
@@ -56,9 +56,29 @@ def get_mxfp_intrin_group(
     use_twiddling: bool = False,
 ) -> Dict[str, str]:
     """
-    This function is used to get the intrinsic group of the MXFP operation to avoid the overhead of fast decoding.
-    MXFP is a type of logic operation that takes three inputs. The intrinsic group refers to the set of
-    intrinsic operations that can be performed on these inputs. This function retrieves and returns this group.
+    Return metadata for an MXFP decoding intrinsic: the function name and the C source string.
+
+    Validates the requested output dtype, source format, and storage dtype, then constructs a lookup key of the form `fp{source_bit}_to_{f16|bf16}` (appending `_twiddling` when use_twiddling is True) to select the corresponding C source snippet, together with the matching function name `decode_fp{source_bit}_to_{f16|bf16}` (also optionally suffixed with `_twiddling`).
+
+    Parameters:
+        out_dtype: Target floating-point type for decoded values; either "float16" or "bfloat16".
+        source_format: Integer source representation; "int" or "uint".
+        source_bit: Bit width of the packed source format (e.g., 4).
+        storage_dtype: Underlying storage integer dtype (one of "int32", "int8", "uint8").
+        use_twiddling: When True, select the twiddling variant of the decoding intrinsic.
+
+    Returns:
+        A dict with:
+        - "func_name": the generated C function name for the requested decode intrinsic.
+        - "c_source": the C source string for that intrinsic.
+
+    Raises:
+        AssertionError: If out_dtype, source_format, or storage_dtype is not supported.
+        KeyError: If the constructed key does not match any available C source implementation.
     """
     assert out_dtype in ["float16", "bfloat16"
                         ], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'."

tilelang/quantize/quantization.py

Lines changed: 41 additions & 1 deletion
@@ -29,7 +29,32 @@
 # fmt: off
 def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr,
                           dtype: str):
-    assert nbit == 4
+    """
+    Convert a packed 4-bit field stored in a uint8 into a bfloat16 value using an exponent scale.
+
+    This function expects a storage field of width `nbit == 4` packed into the 8-bit input `val` and returns a bfloat16 constructed from the unpacked sign, a scaled exponent, and the 1-bit mantissa.
+
+    Behavior:
+    - Validates `nbit == 4`, `dtype == "bfloat16"`, and `val.dtype == "uint8"` (AssertionError if violated).
+    - Extracts the 4-bit field at position `pos` (fields are packed consecutively in `val`).
+    - Interprets the 4-bit field as: sign = bit 3, exponent = bits 1-2, mantissa = bit 0.
+    - Converts the 2-bit exponent to bf16 exponent space by adding a bias of 126, adds `scale` to that exponent, and clamps the result to the 8-bit exponent range (0..255).
+    - Assembles a 16-bit bfloat16 bit pattern from (sign, biased-and-scaled exponent, mantissa) and returns it reinterpreted as `bfloat16`.
+
+    Parameters:
+        nbit: Must be 4 (width of the packed field).
+        val: uint8 expression containing the packed fields.
+        pos: 0-based index of the field within `val`; used to compute the bit shift.
+        scale: Exponent scale added to the converted exponent (treated as an unsigned integer expression).
+        dtype: Must be "bfloat16".
+
+    Returns:
+        A tir.PrimExpr of dtype "bfloat16" representing the decoded and scaled value.
+    """
+    assert nbit == 4
     assert dtype == "bfloat16"
     assert val.dtype == "uint8"
     mask = tir.const((1 << nbit) - 1, "uint16")
@@ -48,6 +73,21 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
     return val_bf16

 def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True):
+    """
+    Convert two float32 values to bfloat16 and pack them into a single uint32.
+
+    The two inputs v0 and v1 (float32 PrimExprs) are reinterpreted as uint32 bit patterns, optionally rounded to nearest-even by adding a rounding bias, then truncated to their upper 16 bits (the bfloat16 representation). The two 16-bit results are packed into a uint32 with v0 in the lower 16 bits and v1 in the upper 16 bits.
+
+    Parameters:
+        v0 (tir.PrimExpr): First float32 value to convert and pack.
+        v1 (tir.PrimExpr): Second float32 value to convert and pack.
+        round_to_even (bool): If True, apply the round-to-nearest-even bias before truncation (default True).
+
+    Returns:
+        tir.PrimExpr: A uint32 PrimExpr containing the packed bfloat16 representations (v0 in the low 16 bits, v1 in the high 16 bits).
+    """
     mask = tir.const((1 << 16) - 1, "uint32")
     res = []
     for data in [v0, v1]:
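A plain-Python sketch of the pack (the additive bias is the usual round-to-nearest-even trick; the library version builds the equivalent TIR expressions):

    import struct

    def f32x2_to_bf16x2_u32(v0: float, v1: float, round_to_even: bool = True) -> int:
        packed = 0
        for i, v in enumerate((v0, v1)):
            bits = struct.unpack("<I", struct.pack("<f", v))[0]
            if round_to_even:
                bits += 0x7FFF + ((bits >> 16) & 1)        # nearest-even rounding bias
            packed |= ((bits >> 16) & 0xFFFF) << (16 * i)  # v0 low half, v1 high half
        return packed

    assert f32x2_to_bf16x2_u32(1.0, 2.0) == 0x40003F80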
