@@ -82,6 +82,39 @@ def bitnet_158_int8xint2_prefill(
warp_col_tiles=32,
chunk=64,
):
"""
Create a TVM GPU prim_func implementing a block-tiled matrix multiply of dense A by compressed/interleaved low-precision B (2-bit values packed into int8 storage), decoding B to int8 on-chip and accumulating into C.

The returned prim_func expects:
- A: shape (M, K) with dtype `in_dtype` ("float16" or "int8").
- B: compressed storage with shape (N, K/4) and int8 storage layout (four 2-bit elements packed per byte).
- C: output buffer shape (M, N) with dtype `out_dtype` ("float16", "float32", or "int32").

Details:
- Builds a tiled, pipelined kernel using shared memory and warp-level MMA intrinsics (INT4TensorCoreIntrinEmitter). B is loaded from compressed storage, decoded to int8 per thread (via decode_i2u_to_i8s / decode_i2s_to_i8s), and dequantized into a shared buffer consumed by the MMA emitter.
- Tiling parameters:
- block_row_warps, block_col_warps: number of warps per block in row/col.
- warp_row_tiles, warp_col_tiles: tiles per warp.
- chunk: K-sized chunk per block (block_K).
- micro sizes are fixed (16x16x16, except micro_k=32 when accum_dtype == "int32").
- Uses 2-stage pipelining by default to overlap loads and compute and applies a swizzle layout to improve L2 behavior.
- Assertions: raises AssertionError if in_dtype or out_dtype are not among supported values.

Parameters:
M, N, K (int): Global matrix dimensions.
in_dtype (str): Input and decoded B element dtype; "float16" or "int8".
out_dtype (str): Output C dtype; one of "float16", "float32", "int32".
accum_dtype (str): Accumulator dtype used by MMA (e.g., "int32").
fast_decoding (bool): If True, enable the fast decoding path (affects which device decode is used).
block_row_warps (int): Warps in block row dimension.
block_col_warps (int): Warps in block column dimension.
warp_row_tiles (int): Tiles per warp in row dimension.
warp_col_tiles (int): Tiles per warp in column dimension.
chunk (int): K-length per block (block_K).

Returns:
T.prim_func: A TVM prim_func implementing the described GPU kernel suitable for compilation and execution.
"""
assert in_dtype in [
"float16",
"int8",
@@ -152,6 +185,23 @@ def main(
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((M, N), out_dtype),
):
"""
GPU kernel entry that performs a blocked, pipelined matrix multiplication A @ B.T, writing the result into C.

This kernel:
- Loads tiles of A and a compressed/interleaved representation of B from global memory into shared memory.
- Decodes B's packed low-precision format (storage_dtype, e.g., 2-bit packed) into element values of `in_dtype` in shared memory via an external decode routine.
- Uses warp-level MMA tile fragments and an INT4/INT2-capable MMA emitter to accumulate across K in a pipelined fashion with a configurable number of stages.
- Writes accumulated tile results from shared memory back to global C with the expected block/micro-tile indexing.

Parameters:
A: Input matrix buffer of shape A_shape and element type `in_dtype`. Represents the MxK activations.
B: Compressed/interleaved weight buffer of shape B_shape and storage type `storage_dtype`. Must contain B in the packed low-precision layout expected by the decode routine used by this kernel.
C: Output buffer of shape (M, N) and type `out_dtype`; receives the resulting matrix (accumulated values are produced in `accum_dtype` and stored into C).

Side effects:
Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation.
"""
with T.Kernel(
T.ceildiv(N, block_N),
T.ceildiv(M, block_M),
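To make the storage convention above concrete, here is a small NumPy sketch of the 2-bit packing the docstring describes: four 2-bit values per int8 byte, so an (N, K) weight tensor compresses to (N, K // 4). The little-endian bit order and the absence of interleaving are illustrative assumptions; the kernel's fast-decoding path additionally expects an interleaved layout.

```python
import numpy as np

def pack_int2(w: np.ndarray) -> np.ndarray:
    """Pack values in [0, 3] along the last axis, four per int8 byte."""
    assert w.shape[-1] % 4 == 0 and ((w >= 0) & (w <= 3)).all()
    w = w.astype(np.uint8).reshape(*w.shape[:-1], -1, 4)
    packed = w[..., 0] | (w[..., 1] << 2) | (w[..., 2] << 4) | (w[..., 3] << 6)
    return packed.view(np.int8)

def unpack_int2(b: np.ndarray) -> np.ndarray:
    """Inverse of pack_int2: recover four 2-bit values per byte."""
    b = b.view(np.uint8)
    vals = np.stack([(b >> (2 * i)) & 0b11 for i in range(4)], axis=-1)
    return vals.reshape(*b.shape[:-1], -1)

w = np.random.randint(0, 4, size=(8, 16))        # (N, K)
assert pack_int2(w).shape == (8, 4)              # (N, K // 4)
assert (unpack_int2(pack_int2(w)) == w).all()
```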
197 changes: 197 additions & 0 deletions examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py

Large diffs are not rendered by default.

193 changes: 192 additions & 1 deletion examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py

Large diffs are not rendered by default.

36 changes: 36 additions & 0 deletions examples/dequantize_gemm/utils.py
@@ -2,6 +2,20 @@


def torch_convert_bit_twiddling(tensor):
"""
Convert a 2-D uint8 tensor into a bfloat16 tensor by decoding pairs of input bytes with a bit-twiddling scheme.

This function expects `tensor` to be a 2-D torch.Tensor of dtype `torch.uint8`. Each output element is produced by combining two input bytes and extracting a bf16-like 16-bit pattern according to one of four positional bit layouts (pos 0..3). The result is scaled by 2**126 to adjust the exponent bias and returned as dtype `torch.bfloat16`.

Parameters:
tensor (torch.Tensor): 2-D input tensor with dtype `torch.uint8`. Shape (N, K).

Returns:
torch.Tensor: New tensor of dtype `torch.bfloat16` with shape (N, K*2); each pair of input bytes yields four bf16 outputs, i.e., two output columns per input column.

Raises:
AssertionError: If the byte inputs used for a conversion are not of dtype `torch.uint8`.
"""

def _convert(val0, val1, pos) -> torch.bfloat16:
assert val0.dtype == torch.uint8
@@ -37,6 +51,19 @@ def _convert(val0, val1, pos) -> torch.bfloat16:


def torch_convert(tensor, scale_size=None, Scale=None):
"""
Decode a 2D uint8 tensor into a 2D bfloat16 tensor by expanding each byte into two bf16 values using a 4-bit (nibble) encoding.

Each input byte holds two 4-bit encoded values (low and high nibble). For each nibble this function derives a sign bit, a 2-bit exponent fragment, and a 1-bit mantissa fragment, assembles a 16-bit bf16 pattern, and returns the resulting tensor with shape (N, K*2) and dtype torch.bfloat16 on the same device as the input.

Parameters:
tensor (torch.Tensor): 2D tensor of dtype torch.uint8 and shape (N, K). Each byte contains two encoded 4-bit entries that become two bf16 values.
scale_size (int, optional): If provided, controls how elements of the optional Scale tensor are indexed. When supplied, per-output-element scaling is applied to the exponent using Scale.
Scale (torch.Tensor, optional): A 2D tensor used to supply per-element integer scale adjustments to the exponent. If scale_size is provided, the scale used for output element (i, j) is Scale[i][j // scale_size].

Returns:
torch.Tensor: A new tensor of shape (N, K*2) and dtype torch.bfloat16 containing the decoded bf16 values.
"""

def _convert(val, pos, scale=None):
assert val.dtype == torch.uint8
@@ -67,6 +94,15 @@ def _convert(val, pos, scale=None):


def print_bit(name, val):
"""
Print the 32-bit binary representation of a CPU scalar extracted from a PyTorch tensor.

Converts `val` to CPU, reads its Python scalar with `.item()`, formats it as a 32-bit binary string, and prints it prefixed by `name`.

Parameters:
name (str): Label printed before the binary representation.
val (torch.Tensor): A scalar PyTorch tensor (numeric) whose 32-bit binary representation will be shown.
"""
val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}'
print(name, binary_repr)
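A quick usage sketch for print_bit (illustrative value; assumes the helper is in scope from this module):

```python
import torch

x = torch.tensor(0b1011, dtype=torch.int32)
print_bit("x =", x)
# prints: x = 00000000000000000000000000001011
```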
30 changes: 30 additions & 0 deletions examples/gemm/example_gemm_autotune.py
@@ -11,10 +11,40 @@


def ref_program(A, B):
"""
Compute the matrix product of A and the transpose of B.

A and B are expected to be 2-D tensors where A has shape (M, K) and B has shape (N, K). The result is a tensor with shape (M, N) equal to A @ B.T, using the inputs' dtypes.
"""
return A @ B.T


def get_configs(M, N, K, with_roller=False, topk=20):
"""
Generate a list of kernel tuning configuration dictionaries for a tiled matrix-multiply.

When with_roller is True this queries the MatmulTemplate roller to produce up to `topk` recommended
configurations (device-specific TensorCore-friendly tilings). Each returned dict contains:
- block_M, block_N, block_K: tile sizes
- num_stages: pipeline staging (0 means no explicit staging)
- thread_num: total threads used for the block
- enable_rasteration: whether a rasterization/swizzle layout was recommended (the key name deliberately keeps this historical misspelling)

When with_roller is False this returns the Cartesian product of a fixed set of candidate
parameters; the returned dicts use the backward-compatible key name "enable_rasteration" for that flag.

Parameters:
M, N, K (int): GEMM dimensions used to generate valid tile sizes.
with_roller (bool): If True, use MatmulTemplate's roller to generate device-aware hints;
otherwise use a predefined candidate grid.
topk (int): Maximum number of roller hints to request when with_roller is True.

Returns:
List[dict]: A list of configuration dictionaries as described above.

Raises:
ValueError: if with_roller is True but the roller returns no hints.
"""
if with_roller:
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
carve_template = MatmulTemplate(
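A minimal sketch of the non-roller path described above: the Cartesian product of a candidate grid. The candidate values here are illustrative, not the exact grid used by the example.

```python
import itertools

def get_configs_grid():
    keys = ("block_M", "block_N", "block_K", "num_stages", "thread_num", "enable_rasteration")
    candidates = (
        [64, 128, 256],    # block_M
        [64, 128, 256],    # block_N
        [32, 64],          # block_K
        [0, 1, 2, 3],      # num_stages (0 = no explicit staging)
        [128, 256],        # thread_num
        [True, False],     # enable_rasteration (historical key spelling)
    )
    return [dict(zip(keys, combo)) for combo in itertools.product(*candidates)]

print(len(get_configs_grid()))  # 288 configurations
```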
13 changes: 13 additions & 0 deletions tilelang/intrinsics/utils.py
@@ -76,6 +76,19 @@ def mfma_store_index_map(thread_id, local_id):
def get_mma_micro_size(dtype: Literal["float16", "int8"]):
# TODO(lei): FP8 related precision support.
# Basic Tensor Core Matrix Multiply operation Unit
"""
Return the MMA (Tensor Core) micro-tile dimensions for a given data type.

This function returns the micro tile sizes (x, y, k) used by MMA/Tensor Core operations.
- x: tile width in the output/result dimension
- y: tile height in the output/result dimension
- k: tile depth in the reduction/K dimension

Accepted dtype strings include "float16", "int8" and some FP8 identifiers ("float8_e4m3", "float8_e5m2"). For FP8 and int8 types the reduction depth (`k`) is 32; for float16 it is 16.

Returns:
tuple[int, int, int]: (micro_size_x, micro_size_y, micro_size_k)
"""
micro_size_x = micro_size_y = 16
micro_size_k = 16
if dtype in {"float8_e4m3", "float8_e5m2", "int8"}:
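For reference, the micro sizes implied by the rules above, as a minimal usage sketch (assuming the import path matches this file):

```python
from tilelang.intrinsics.utils import get_mma_micro_size

assert get_mma_micro_size("float16") == (16, 16, 16)
assert get_mma_micro_size("int8") == (16, 16, 32)  # reduction depth doubles for 8-bit types
```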
11 changes: 10 additions & 1 deletion tilelang/language/__init__.py
@@ -17,7 +17,6 @@
make_tensor, # noqa: F401
Buffer, # noqa: F401
Tensor, # noqa: F401
StridedTensor, # noqa: F401
FragmentBuffer, # noqa: F401
SharedBuffer, # noqa: F401
LocalBuffer, # noqa: F401
@@ -73,6 +72,16 @@


def symbolic(name: str, dtype: str = "int32"):
"""
Create a TIR symbolic variable.
Parameters:
name (str): Identifier for the variable in generated TIR.
dtype (str): Data type string for the variable (e.g., "int32"). Defaults to "int32".
Returns:
tir.Var: A TIR variable with the given name and dtype for use in TIR/TensorIR kernels.
"""
return tir.Var(name, dtype)


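A short usage sketch for symbolic, e.g. for declaring dynamic shape variables:

```python
from tilelang.language import symbolic

m = symbolic("m")            # defaults to int32
k = symbolic("k", "int64")
print(m.dtype, k.dtype)      # int32 int64
```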
61 changes: 32 additions & 29 deletions tilelang/language/utils.py
@@ -4,24 +4,16 @@

def index_to_coordinates(index, shape) -> list[PrimExpr]:
"""
Convert a flat (linear) index to multi-dimensional coordinates for a given shape.

Example:
shape = (4, 5, 6)
index = 53
index_to_coordinates(53, (4, 5, 6)) -> [1, 3, 5]
# Explanation:
# 53 // (5*6) = 1 (1st coordinate)
# 53 % (5*6) = 23
# 23 // 6 = 3 (2nd coordinate)
# 23 % 6 = 5 (3rd coordinate)

Args:
index (int): The flat index to convert.
shape (tuple or list of int): The shape of the multi-dimensional array.

Convert a flat (linear) index into multi-dimensional coordinates for a given shape.

Given a linear index and a shape (sequence of dimension extents), returns a list of coordinates (one per dimension) such that converting those coordinates back to a linear index using the usual row-major / C-order formula yields the original index. The computation iterates from the last dimension to the first using modulo and integer division, then reverses the collected coordinates.

Parameters:
index (int or PrimExpr): The flat index to convert.
shape (Sequence[int]): The extents of each dimension (length >= 1).

Returns:
list: A list of coordinates corresponding to each dimension.
list[PrimExpr]: Coordinates for each dimension in the same order as `shape`.
"""
coordinates = []
dims = len(shape)
@@ -34,18 +26,29 @@ def index_to_coordinates(index, shape) -> list[PrimExpr]:

def linear_index(*args: PrimExpr) -> PrimExpr:
"""
Convert a list of coordinates to a flat (linear) index using strides.

Usage examples:
linear_index(i) -> i
linear_index(i, j) -> i * stride + j
linear_index(i, j, stride_j) -> i * stride_j + j
linear_index(i, j, k, stride_j, stride_k)
-> i * stride_j * stride_k + j * stride_k + k

Example for index = i * threads * local_size + tx * local_size + v:
Suppose you have i, tx, v as coordinates, and threads, local_size as strides:
linear_index(i, tx, v, threads, local_size) == i * threads * local_size + tx * local_size + v
Compute a flat (linear) index from multi-dimensional coordinates and strides.

The function accepts a sequence of PrimExpr arguments where the first portion are coordinates
and the trailing portion are the corresponding strides. The number of strides must equal
(number of coordinates - 1). The linear index is computed as:

linear = coords[0]
for each (coord, stride) in zip(coords[1:], strides):
linear = linear * stride + coord

Examples:
- linear_index(i) -> i
- linear_index(i, j, stride_j) -> i * stride_j + j
- linear_index(i, j, k, stride_j, stride_k) -> i*stride_j*stride_k + j*stride_k + k
- linear_index(i, tx, v, threads, local_size) -> i*threads*local_size + tx*local_size + v

Raises:
ValueError: If called with no arguments, or if the number of strides is not one less than
the number of coordinates.

Returns:
PrimExpr: The computed linear index expression.
"""
n = len(args)
if n == 0:
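A plain-Python round-trip check of the two helpers above, substituting ints for PrimExprs and following the row-major convention the docstrings state:

```python
shape = (4, 5, 6)
flat = 53
i, j, k = index_to_coordinates(flat, shape)               # [1, 3, 5]
assert linear_index(i, j, k, shape[1], shape[2]) == flat  # strides for dims 1..2
```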
26 changes: 23 additions & 3 deletions tilelang/quantize/mxfp.py
@@ -56,9 +56,29 @@ def get_mxfp_intrin_group(
use_twiddling: bool = False,
) -> Dict[str, str]:
"""
This function is used to get the intrinsic group of the MXFP operation to avoid the overhead of fast decoding.
MXFP is a type of logic operation that takes three inputs. The intrinsic group refers to the set of
intrinsic operations that can be performed on these inputs. This function retrieves and returns this group.
Return metadata for an MXFP decoding intrinsic: function name and C source string.

Validates the requested output dtype, source format, and storage dtype, then constructs
a lookup key of the form `fp{source_bit}_to_{f16|bf16}` (appending `_twiddling` when
use_twiddling is True) to select the corresponding C source snippet and a matching
function name `decode_fp{source_bit}_to_{f16|bf16}` (also optionally suffixed with
`_twiddling`).

Parameters:
out_dtype: Target floating-point type for decoded values; either "float16" or "bfloat16".
source_format: Integer source representation; "int" or "uint".
source_bit: Bit width of the packed source format (e.g., 4).
storage_dtype: Underlying storage integer dtype (one of "int32", "int8", "uint8").
use_twiddling: When True, select the twiddling variant of the decoding intrinsic.

Returns:
A dict with:
- "func_name": the generated C function name string for the requested decode intrinsic.
- "c_source": the C source string for that intrinsic.

Raises:
AssertionError: if out_dtype, source_format, or storage_dtype are not supported.
KeyError: if the constructed key does not match any available C source implementation.
"""
assert out_dtype in ["float16", "bfloat16"
], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'."
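A sketch of the key and function-name construction the docstring describes; the real lookup table lives in tilelang.quantize.mxfp, so the names here only mirror the documented scheme:

```python
def mxfp_decode_names(out_dtype: str, source_bit: int, use_twiddling: bool):
    dtype_tag = "bf16" if out_dtype == "bfloat16" else "f16"
    suffix = "_twiddling" if use_twiddling else ""
    key = f"fp{source_bit}_to_{dtype_tag}{suffix}"
    func_name = f"decode_fp{source_bit}_to_{dtype_tag}{suffix}"
    return key, func_name

print(mxfp_decode_names("bfloat16", 4, True))
# ('fp4_to_bf16_twiddling', 'decode_fp4_to_bf16_twiddling')
```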
40 changes: 40 additions & 0 deletions tilelang/quantize/quantization.py
@@ -29,6 +29,31 @@
# fmt: off
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr,
dtype: str):
"""
Convert a packed 4-bit field stored in a uint8 into a bfloat16 value using an exponent scale.

This function expects a storage field of width `nbit == 4` packed into the 8-bit input `val` and returns
a bfloat16 constructed from the unpacked sign, a scaled exponent, and the 1-bit mantissa.

Behavior:
- Validates `nbit == 4`, `dtype == "bfloat16"`, and `val.dtype == "uint8"` (AssertionError if violated).
- Extracts the 4-bit field at position `pos` (fields are packed consecutively in `val`).
- Interprets the 4-bit field as: sign = bit3, exponent = bits1-2, mantissa = bit0.
- Converts the 2-bit exponent to bf16 exponent space by adding a bias of 126, adds `scale` to that exponent,
and clamps the result to the 8-bit exponent range (0..255).
- Assembles a 16-bit bfloat16 bit pattern from (sign, biased-and-scaled-exponent, mantissa) and
returns it reinterpreted as `bfloat16`.

Parameters:
- nbit: must be 4 (width of the packed field).
- val: uint8 expression containing packed fields.
- pos: index of the field within `val` (0-based); used to compute the bit shift.
- scale: exponent-scale to add to the converted exponent (treated as an unsigned integer expression).
- dtype: must be "bfloat16".

Returns:
- A tir.PrimExpr of dtype "bfloat16" representing the decoded and scaled value.
"""
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
@@ -48,6 +73,21 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
return val_bf16

def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True):
"""
Convert two float32 values to bfloat16 and pack them into a single uint32.

The two inputs v0 and v1 (float32 PrimExpr) are reinterpreted as uint32 bit patterns, optionally rounded to nearest-even
by adding a rounding bias, then truncated to their upper 16 bits (bfloat16 representation). The two 16-bit results are
packed into a uint32 with v0 in the lower 16 bits and v1 in the upper 16 bits.

Parameters:
v0 (tir.PrimExpr): First float32 value to convert and pack.
v1 (tir.PrimExpr): Second float32 value to convert and pack.
round_to_even (bool): If True, apply round-to-nearest-even bias before truncation (default True).

Returns:
tir.PrimExpr: A uint32 PrimExpr containing the packed bfloat16 representations (v0 low 16 bits, v1 high 16 bits).
"""
mask = tir.const((1 << 16) - 1, "uint32")
res = []
for data in [v0, v1]:
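A NumPy sketch of the conversion just described: round-to-nearest-even truncation of two float32 values to bfloat16, packed low/high into one uint32. This mirrors the documented behavior, not the TIR implementation itself.

```python
import numpy as np

def f32x2_to_bf16x2_u32(v0: float, v1: float, round_to_even: bool = True) -> int:
    out = 0
    for i, v in enumerate((v0, v1)):
        bits = int(np.float32(v).view(np.uint32))
        if round_to_even:
            bits += 0x7FFF + ((bits >> 16) & 1)  # nearest-even rounding bias
        out |= ((bits >> 16) & 0xFFFF) << (16 * i)
    return out

print(hex(f32x2_to_bf16x2_u32(1.0, -2.0)))  # 0xc0003f80: v0 in low 16 bits, v1 in high
```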