import tilelang.language as T
from typing import Literal, Callable, Optional
from tvm.tir import IndexMap
from tilelang.intrinsics.utils import get_mma_micro_size
from tilelang.tools import plot_layout

from tilelang.intrinsics.mfma_layout import (
    shared_16x4_to_local_64x1_layout_A,
    shared_16x16_to_local_64x4_layout_A,
    shared_16x32_to_local_64x8_layout_A,
    shared_16x64_to_local_64x16_layout_A,
)


def make_mfma_load_base_layout(dtype: str = "float16",
                               matrix: Literal["A", "B"] = "A",
                               k_dim: int = 16,
                               transposed: bool = False) -> T.Fragment:
    """
    Create the base layout for loading an MFMA operand tile from shared
    memory into a per-thread fragment buffer, mapping each fragment index
    to a thread (lane) id and a thread-local index.

    Parameters
    ----------
    dtype : str
        The data type of the matrix.
    matrix : Literal["A", "B"]
        The mfma operand to be loaded.
    k_dim : int
        The k dimension of the mfma.
    transposed : bool
        Whether the matrix is transposed, by default False.

    Returns
    -------
    T.Fragment
        Describes how threads and local indices in the fragment are laid out.
    """

    assert matrix in ["A", "B"], "matrix should be either A or B"
    # "s" denotes a spatial axis and "r" a reduction axis:
    # "sr" means the two dims are ordered spatial then reduction,
    # "rs" means reduction then spatial.
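    # For example (assuming the usual row-major [M, K] x [K, N] GEMM
    # convention): non-transposed A is [M, K] = spatial-then-reduction ("sr"),
    # while non-transposed B is [K, N] = reduction-then-spatial ("rs").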
    transform_func_sr_a: Optional[Callable] = None
    transform_func_sr_b: Optional[Callable] = None

    if k_dim == 4:
        transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
        transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
    elif k_dim == 16:
        transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
        transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
    elif k_dim == 32:
        transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
        transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
    elif k_dim == 64:
        transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
        transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
    else:
        raise ValueError("k_dim must be one of 4, 16, 32, or 64")
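    # For instance, k_dim=16 spreads a 16x16 shared-memory tile across the
    # 64 lanes of a wavefront with 4 elements per lane (64 * 4 = 16 * 16).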

    is_sr_axis_order = ((matrix == "A" and not transposed) or
                        (matrix == "B" and transposed))

    micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype)

    # The mfma operand layout is row.col, so the B matrix
    # expects a transposed base layout.
    transform_func: Optional[Callable] = None
    if matrix == "A":
        transform_func = (transform_func_sr_a if is_sr_axis_order else
                          lambda i, j: transform_func_sr_a(j, i))
        micro_size_s, micro_size_r = micro_size_x, micro_size_k
    elif matrix == "B":
        transform_func = (transform_func_sr_b if is_sr_axis_order else
                          lambda i, j: transform_func_sr_b(j, i))
        micro_size_s, micro_size_r = micro_size_k, micro_size_y
    else:
        raise ValueError(f"Unsupported matrix {matrix}")

    inverse_mfma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")

    def forward_thread(i: int, j: int) -> int:
        """
        Given the row index `i` and column index `j` in the fragment,
        return the lane (thread) id that holds the element.
        """
        lane_id, _ = inverse_mfma_load_layout.map_indices([i, j])
        return lane_id

    def forward_index(i: int, j: int) -> int:
        """
        Given the row index `i` and column index `j` in the fragment,
        return the element's local index within that thread's fragment.
        """
        _, local_id = inverse_mfma_load_layout.map_indices([i, j])
        return local_id

    base_fragment = T.Fragment(
        [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
        forward_thread_fn=forward_thread,
        forward_index_fn=forward_index,
    )
    return base_fragment

block_rows = 2
block_cols = 2
warp_rows = 2
warp_cols = 2
chunk = 2

# base mfma load layout (16x16 tile)
base_layout = make_mfma_load_base_layout(dtype="float16", matrix="A", transposed=False)
print(base_layout)
plot_layout(base_layout, name="base_layout")
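
# Quick sanity check (not part of the layout API): call the imported mapping
# helper directly to see how a few (row, col) coordinates of the 16x16 shared
# tile land on (lane, local register) pairs across the 64-lane wavefront.
for i, j in [(0, 0), (0, 3), (0, 4), (15, 15)]:
    lane, local_id = shared_16x16_to_local_64x4_layout_A(i, j)
    print(f"shared ({i:2d}, {j:2d}) -> lane {lane:2d}, local {local_id}")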

# warp layout 32x32
warp_layout = base_layout.repeat([warp_rows, warp_cols],
                                 repeat_on_thread=False,
                                 lower_dim_first=False)
print(warp_layout)
plot_layout(warp_layout, name="warp_layout")
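
# Note on the repeat flags above (our reading of tilelang's Fragment API):
# repeat_on_thread=False repeats the base tile within each thread, so every
# lane of the same 64-lane wavefront now holds 2x2 = 4 base tiles' worth of
# registers, while lower_dim_first picks which axis the repeats traverse first.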

# block layout 64x32
block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True,
                                  lower_dim_first=True).replicate(block_cols)
print(block_layout)
plot_layout(block_layout, name="block_layout")
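
# A minimal extension of the example (same pattern as above): the helper also
# builds the B-operand base layout, which uses the transposed (rs) axis order.
b_base_layout = make_mfma_load_base_layout(dtype="float16", matrix="B", transposed=False)
print(b_base_layout)
plot_layout(b_base_layout, name="b_base_layout")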