import tilelang.language as T
from typing import Literal, Callable
from tvm.tir import IndexMap
from tilelang.intrinsics.utils import get_mma_micro_size
from tilelang.tools import plot_layout

from tilelang.intrinsics.mfma_layout import (
    shared_16x4_to_local_64x1_layout_A,
    shared_16x16_to_local_64x4_layout_A,
    shared_16x32_to_local_64x8_layout_A,
    shared_16x64_to_local_64x16_layout_A,
)


| 13 | + |
def make_mfma_load_base_layout(dtype: str = "float16",
                               matrix: Literal["A", "B"] = "A",
                               k_dim: int = 16,
                               transposed: bool = False) -> T.Fragment:
    """
    Create the base layout for loading an MFMA operand fragment.
    The layout maps 2D fragment indices (i, j) to a (thread, local index)
    pair via the shared->local index maps from `tilelang.intrinsics.mfma_layout`.

    Parameters
    ----------
    dtype : str
        The data type of the matrix.
    matrix : Literal["A", "B"]
        The mfma operand to be loaded.
    k_dim : int
        The k dimension of the mfma.
    transposed : bool
        Whether the matrix is transposed, by default False.

    Returns
    -------
    T.Fragment
        Describes how threads and indices in the fragment are laid out.

    Raises
    ------
    ValueError
        If `k_dim` is not one of 4, 16, 32, 64.
    """
    assert matrix in ["A", "B"], "matrix should be either A or B"

    # Shared->local index maps keyed on the k dimension.  The same
    # (spatial, reduction)-ordered map serves both operands A and B; the
    # axis order is swapped below when the operand is stored
    # reduction-major instead.
    sr_transform_by_k = {
        4: shared_16x4_to_local_64x1_layout_A,
        16: shared_16x16_to_local_64x4_layout_A,
        32: shared_16x32_to_local_64x8_layout_A,
        64: shared_16x64_to_local_64x16_layout_A,
    }
    if k_dim not in sr_transform_by_k:
        raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")
    transform_func_sr: Callable = sr_transform_by_k[k_dim]

    # s represents spatial axis, r represents reduction axis.
    # The axes are ordered (spatial, reduction) exactly when
    #   (matrix == "A" and not transposed) or (matrix == "B" and transposed),
    # which, given matrix in {"A", "B"}, is equivalent to:
    is_sr_axis_order = (matrix == "A") != transposed

    micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype)

    # The layout of mma.sync is row.col, so the B operand expects a
    # transposed basic layout; pick the micro tile extents per operand.
    if matrix == "A":
        micro_size_s, micro_size_r = micro_size_x, micro_size_k
    else:  # matrix == "B" (validated above)
        micro_size_s, micro_size_r = micro_size_k, micro_size_y

    # When the axes are reduction-major, swap (i, j) before applying the
    # spatial-major transform.
    if is_sr_axis_order:
        transform_func: Callable = transform_func_sr
    else:
        transform_func = lambda i, j: transform_func_sr(j, i)

    inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")

    def forward_thread(i: int, j: int) -> int:
        """Map fragment indices (i, j) to the owning lane (thread) id."""
        lane_id, _ = inverse_mma_load_layout.map_indices([i, j])
        return lane_id

    def forward_index(i: int, j: int) -> int:
        """Map fragment indices (i, j) to the local (per-thread) index."""
        _, local_id = inverse_mma_load_layout.map_indices([i, j])
        return local_id

    base_fragment = T.Fragment(
        [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
        forward_thread_fn=forward_thread,
        forward_index_fn=forward_index,
    )
    return base_fragment
| 107 | + |
| 108 | + |
# ---------------------------------------------------------------------------
# Demo: visualize the MFMA load layout at increasing tiling levels.
# NOTE: this runs at import time — it prints each layout and writes a plot
# per level via `plot_layout` (imported at the top of the file).
# ---------------------------------------------------------------------------

# Tiling configuration for the demo.
block_rows = 2
block_cols = 2
warp_rows = 2
warp_cols = 2
chunk = 2  # kept for parity with the tiling config; unused in the demo below

# Base (single-instruction) layout: 16x16 fragment of operand A.
base_layout = make_mfma_load_base_layout(dtype="float16", matrix="A", transposed=False)
print(base_layout)
plot_layout(base_layout, name="base_layout")

# Warp-level layout 32x32: repeat the base tile across the warp.
warp_layout = base_layout.repeat([warp_rows, warp_cols],
                                 repeat_on_thread=False,
                                 lower_dim_first=False)
print(warp_layout)
plot_layout(warp_layout, name="warp_layout")

# Block-level layout 64x32: repeat across warps on threads, then replicate
# over the block columns.
block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True,
                                  lower_dim_first=True).replicate(block_cols)
print(block_layout)
plot_layout(block_layout, name="block_layout")
0 commit comments