# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
#
# Modified to implement FlashAttention-2 forward pass principles.
# Corrected loop implementation using T.While.

import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
import itertools
import argparse
from functools import partial


# PyTorch reference implementation (unchanged).
def ref_program(Q, K, V, is_causal, groups=1):
    assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
    assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
    dim = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output


def get_v2_configs():
    """Generates configurations for the autotuner, tailored for FA-2 style parallelism."""
    block_M = [64, 128, 256]
    block_N = [32, 64, 128]
    threads = [128, 256, 512]
    num_split_q = [32, 64, 128]
    num_stages = [1, 2, 3]
    enable_rasterization = [True]
    k_pack = [2]

    valid_configs = []

    for m, n, s, t, stages, r, k in itertools.product(block_M, block_N, num_split_q, threads,
                                                      num_stages, enable_rasterization, k_pack):
        valid_configs.append({
            "block_M": m,
            "block_N": n,
            "num_split_q": s,
            "threads": t,
            "num_stages": stages,
            "enable_rasterization": r,
            "k_pack": k
        })
    if not valid_configs:
        valid_configs.append({
            'block_M': 64,
            'block_N': 64,
            'num_split_q': 64,
            'threads': 256,
            'num_stages': 1,
            'enable_rasterization': True,
            'k_pack': 2
        })
    return valid_configs


@tilelang.autotune(configs=get_v2_configs(), cache_input_tensors=True)
@tilelang.jit(out_idx=[3])
def fast_flashattn_v2(
    batch,
    heads,
    seq_len,
    dim,
    is_causal,
    groups,
    block_M: int,
    block_N: int,
    num_split_q: int,
    threads: int,
    num_stages: int,
    enable_rasterization: bool,
    k_pack: int,
):
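    # Softmax scale 1/sqrt(dim) folded with log2(e) (~1.44269504) so the inner
    # loop can use exp2 instead of exp.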
    scale = (1.0 / dim)**0.5 * 1.44269504
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    kv_shape = [batch, seq_len, head_kv, dim]
    dtype = "float16"
    accum_dtype = "float"

    v_vec_size = 4

    vec_size = 4 * k_pack

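    # Inner loop for one query block: stream K/V tiles, maintain the running
    # row max (m_i) and row sum (l_i), and rescale the output accumulator
    # (acc_o) on the fly, FlashAttention-2 online-softmax style.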
    @T.macro
    def compute_block(
            bz,
            by,
            bx,
            Q: T.Tensor(q_shape, dtype),
            K: T.Tensor(kv_shape, dtype),
            V: T.Tensor(kv_shape, dtype),
            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            m_i: T.FragmentBuffer([block_M], accum_dtype),
            l_i: T.FragmentBuffer([block_M], accum_dtype),
    ):
        Q_shared = T.alloc_shared([block_M, dim], dtype)
        K_shared = T.alloc_shared([block_N, dim], dtype)
        V_shared = T.alloc_shared([block_N, dim], dtype)
        P_shared = T.alloc_shared([block_M, block_N], dtype)

        acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
        m_prev = T.alloc_fragment([block_M], accum_dtype)
        scale_factor = T.alloc_fragment([block_M], accum_dtype)

        q_block_offset = bx * block_M
        T.copy(
            Q[bz, q_block_offset:q_block_offset + block_M, by, :],
            Q_shared,
            coalesced_width=vec_size)

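        # For causal attention only K/V tiles up to the end of this query block
        # are needed; otherwise iterate over the full sequence.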
        loop_end_k = T.ceildiv(q_block_offset + block_M,
                               block_N) if is_causal else T.ceildiv(seq_len, block_N)
        for k in T.Pipelined(loop_end_k, num_stages=num_stages):
            kv_idx = k * block_N
            T.copy(
                K[bz, kv_idx:kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size)
            T.copy(
                V[bz, kv_idx:kv_idx + block_N, by // groups, :],
                V_shared,
                coalesced_width=v_vec_size)

            T.clear(acc_s)
            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, k_pack=k_pack)

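            # Causal mask: scores whose key index exceeds the query index are
            # set to -inf so they contribute nothing after the softmax.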
            if is_causal:
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, acc_s[i, j],
                                                 -T.infinity(acc_s.dtype))

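            # Online softmax update: refresh the running max, rescale the
            # previous accumulator and row sums by exp2((m_prev - m_new) * scale),
            # then accumulate the exponentiated scores of this tile.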
            T.copy(m_i, m_prev)
            T.reduce_max(acc_s, m_i, dim=1, clear=False)

            for i in T.Parallel(block_M):
                sf = T.exp2(m_prev[i] * scale - m_i[i] * scale)
                l_i[i] *= sf
                scale_factor[i] = sf

            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] *= scale_factor[i]

            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.exp2(acc_s[i, j] * scale - m_i[i] * scale)

            row_sum = T.alloc_fragment([block_M], accum_dtype)
            T.reduce_sum(acc_s, row_sum, dim=1)
            for i in T.Parallel(block_M):
                l_i[i] += row_sum[i]

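            # Stage the probabilities in fp16 shared memory before the P @ V
            # GEMM that accumulates into acc_o.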
            T.copy(acc_s, P_shared)
            T.sync_threads()

            T.gemm(P_shared, V_shared, acc_o)

    # Fix: the macro is defined outside the kernel for a cleaner code structure.
    @T.macro
    def scale_and_write_back(src_buffer, scale_vector, dest_tensor, bz, by, q_block_offset):
        # This macro performs the fused scaling and write-back, which is critical for performance.
        for i, j in T.Parallel(block_M, dim):
            dest_tensor[bz, q_block_offset + i, by, j] = src_buffer[i, j] * scale_vector[i]

    @T.macro
    def flash_attn_forward_kernel(Q: T.Tensor(q_shape, dtype), K: T.Tensor(kv_shape, dtype),
                                  V: T.Tensor(kv_shape, dtype), Output: T.Tensor(q_shape, dtype)):
        with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
            T.use_swizzle(10, enable=enable_rasterization)

            bz = byz_combined // heads
            by = byz_combined % heads

            num_q_blocks = T.ceildiv(seq_len, block_M)

            bx = T.alloc_var("int32")
            bx[0] = b_split

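            # Persistent scheduling over query blocks: each of the num_split_q
            # program instances handles blocks b_split, b_split + num_split_q, ...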
            with T.While(bx[0] < num_q_blocks):
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                m_i = T.alloc_fragment([block_M], accum_dtype)
                l_i = T.alloc_fragment([block_M], accum_dtype)
                T.fill(acc_o, 0)
                T.fill(m_i, -T.infinity(accum_dtype))
                T.fill(l_i, 0)

                current_bx = bx[0]

                compute_block(bz, by, current_bx, Q, K, V, acc_o, m_i, l_i)

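                # Normalize by the row sums; guard against all-masked rows to
                # avoid dividing by zero.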
                l_inv = T.alloc_fragment([block_M], accum_dtype)
                for i in T.Parallel(block_M):
                    safe_l = T.if_then_else(l_i[i] > 1e-6, l_i[i], 1.0)
                    l_inv[i] = 1.0 / safe_l

                # Fix: the macro call is now clearer to the compiler.
                q_block_offset = current_bx * block_M
                scale_and_write_back(acc_o, l_inv, Output, bz, by, q_block_offset)

                bx[0] = current_bx + num_split_q

    @T.prim_func
    def main(
            Q: T.Tensor(q_shape, dtype),
            K: T.Tensor(kv_shape, dtype),
            V: T.Tensor(kv_shape, dtype),
            Output: T.Tensor(q_shape, dtype),
    ):
        flash_attn_forward_kernel(Q, K, V, Output)

    return main


# The main function is unchanged.
def main_v2(batch: int = 1,
            heads: int = 8,
            seq_len: int = 4096,
            dim: int = 128,
            is_causal: bool = False,
            groups: int = 1):

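    # Two GEMMs (Q @ K^T and P @ V), each 2 * batch * heads * seq_len^2 * dim
    # FLOPs; causal attention touches roughly half of the score matrix.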
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        total_flops *= 0.5

    print("Starting autotuning for FlashAttention-V2...")
    kernel = fast_flashattn_v2(batch, heads, seq_len, dim, is_causal, groups=groups)
    print(f"Autotuning finished. Best Configuration: {kernel.config}")

    ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

    print("Verifying correctness...")
    profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
    print("All checks pass.")

    latency = profiler.do_bench(ref_program_processed, warmup=100)
    print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")

    latency = profiler.do_bench(warmup=100)
    print(
        f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=1, help='batch size')
    parser.add_argument('--heads', type=int, default=8, help='heads')
    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
    parser.add_argument('--dim', type=int, default=128, help='dim')
    parser.add_argument('--is_causal', action='store_true', help='causal')
    parser.add_argument('--groups', type=int, default=1, help='groups')
    args = parser.parse_args()
    main_v2(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)