|
| 1 | +# ruff: noqa |
| 2 | +import argparse |
| 3 | +import torch |
| 4 | +import tilelang |
| 5 | +import tilelang.language as T |
| 6 | +import tilelang.testing |
| 7 | +from einops import rearrange, repeat |
| 8 | +from tilelang.profiler import do_bench |
| 9 | +from varlen_utils import generate_random_padding_mask, generate_qkv |
| 10 | + |
| 11 | +tilelang.disable_cache() |
| 12 | + |
| 13 | + |
| 14 | +def attention_ref( |
| 15 | + q, |
| 16 | + k, |
| 17 | + v, |
| 18 | + query_padding_mask=None, |
| 19 | + key_padding_mask=None, |
| 20 | + causal=False, |
| 21 | + window_size=(-1, -1), |
| 22 | + upcast=True, |
| 23 | +): |
| 24 | + if causal: |
| 25 | + window_size = (window_size[0], 0) |
| 26 | + dtype_og = q.dtype |
| 27 | + if upcast: |
| 28 | + q, k, v = q.float(), k.float(), v.float() |
| 29 | + dim = q.shape[-1] |
| 30 | + scale = (1.0 / dim)**0.5 |
| 31 | + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) |
| 32 | + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) |
| 33 | + scores = torch.einsum("bthd,bshd->bhts", q, k) |
| 34 | + if key_padding_mask is not None: |
| 35 | + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) |
| 36 | + scores = scores * scale |
| 37 | + attention = torch.softmax(scores, dim=-1).to(v.dtype) |
| 38 | + |
| 39 | + if query_padding_mask is not None: |
| 40 | + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) |
| 41 | + output = torch.einsum("bhts,bshd->bthd", attention, v) |
| 42 | + if query_padding_mask is not None: |
| 43 | + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) |
| 44 | + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) |
| 45 | + |
| 46 | + |
| 47 | +@tilelang.jit( |
| 48 | + out_idx=[6], pass_configs={ |
| 49 | + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, |
| 50 | + }) |
| 51 | +def flashattn(batch_size, |
| 52 | + groups, |
| 53 | + UQ, |
| 54 | + UKV, |
| 55 | + heads, |
| 56 | + dim, |
| 57 | + is_causal, |
| 58 | + block_M=64, |
| 59 | + block_N=64, |
| 60 | + num_stages=1, |
| 61 | + threads=128): |
| 62 | + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) |
| 63 | + head_kv = heads // groups |
| 64 | + q_shape = [UQ, heads, dim] |
| 65 | + kv_shape = [UKV, head_kv, dim] |
| 66 | + o_shape = [UQ, heads, dim] |
| 67 | + dtype = "float16" |
| 68 | + accum_dtype = "float" |
| 69 | + |
| 70 | + @T.prim_func |
| 71 | + def main( |
| 72 | + Q_unpad: T.Tensor(q_shape, dtype), |
| 73 | + K_unpad: T.Tensor(kv_shape, dtype), |
| 74 | + V_unpad: T.Tensor(kv_shape, dtype), |
| 75 | + cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), |
| 76 | + cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), |
| 77 | + max_seqlen_q: T.int32, |
| 78 | + Output_unpad: T.Tensor(o_shape, dtype), |
| 79 | + ): |
| 80 | + with T.Kernel( |
| 81 | + T.ceildiv(max_seqlen_q, block_M), heads, batch_size, |
| 82 | + threads=threads) as (bx, by, bz): |
| 83 | + Q_shared = T.alloc_shared([block_M, dim], dtype) |
| 84 | + K_shared = T.alloc_shared([block_N, dim], dtype) |
| 85 | + V_shared = T.alloc_shared([block_N, dim], dtype) |
| 86 | + O_shared = T.alloc_shared([block_M, dim], dtype) |
| 87 | + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) |
| 88 | + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) |
| 89 | + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) |
| 90 | + scores_max = T.alloc_fragment([block_M], accum_dtype) |
| 91 | + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) |
| 92 | + scores_scale = T.alloc_fragment([block_M], accum_dtype) |
| 93 | + scores_sum = T.alloc_fragment([block_M], accum_dtype) |
| 94 | + logsum = T.alloc_fragment([block_M], accum_dtype) |
| 95 | + |
| 96 | + batch_idx = bz |
| 97 | + head_idx = by |
| 98 | + kv_head_idx = head_idx // groups |
| 99 | + |
| 100 | + q_start_idx = cu_seqlens_q[batch_idx] |
| 101 | + k_start_idx = cu_seqlens_k[batch_idx] |
| 102 | + v_start_idx = cu_seqlens_k[batch_idx] |
| 103 | + q_end_idx = cu_seqlens_q[batch_idx + 1] |
| 104 | + k_end_idx = cu_seqlens_k[batch_idx + 1] |
| 105 | + v_end_idx = cu_seqlens_k[batch_idx + 1] |
| 106 | + |
| 107 | + q_current_seqlen = q_end_idx - q_start_idx |
| 108 | + k_current_seqlen = k_end_idx - k_start_idx |
| 109 | + v_current_seqlen = v_end_idx - v_start_idx |
| 110 | + |
| 111 | + T.copy( |
| 112 | + Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :], |
| 113 | + Q_shared) |
| 114 | + for i, d in T.Parallel(block_M, dim): |
| 115 | + if bx * block_M + i >= q_current_seqlen: |
| 116 | + Q_shared[i, d] = 0 |
| 117 | + |
| 118 | + T.fill(acc_o, 0) |
| 119 | + T.fill(logsum, 0) |
| 120 | + T.fill(scores_max, -T.infinity(accum_dtype)) |
| 121 | + |
| 122 | + loop_range = T.ceildiv(k_current_seqlen, block_N) |
| 123 | + |
| 124 | + for k in T.Pipelined(loop_range, num_stages=num_stages): |
| 125 | + T.copy( |
| 126 | + K_unpad[k_start_idx + k * block_N:k_start_idx + (k + 1) * block_N, |
| 127 | + kv_head_idx, :], K_shared) |
| 128 | + for i, d in T.Parallel(block_N, dim): |
| 129 | + if k * block_N + i >= k_current_seqlen: |
| 130 | + K_shared[i, d] = 0 |
| 131 | + |
| 132 | + if is_causal: |
| 133 | + for i, j in T.Parallel(block_M, block_N): |
| 134 | + acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and |
| 135 | + (bx * block_M + i >= q_current_seqlen or |
| 136 | + k * block_N + j >= k_current_seqlen), |
| 137 | + -T.infinity(acc_s.dtype), 0) |
| 138 | + else: |
| 139 | + for i, j in T.Parallel(block_M, block_N): |
| 140 | + acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or |
| 141 | + k * block_N + j >= k_current_seqlen), |
| 142 | + -T.infinity(acc_s.dtype), 0) |
| 143 | + |
| 144 | + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) |
| 145 | + |
| 146 | + T.copy(scores_max, scores_max_prev) |
| 147 | + T.fill(scores_max, -T.infinity(accum_dtype)) |
| 148 | + T.reduce_max(acc_s, scores_max, dim=1, clear=False) |
| 149 | + |
| 150 | + for i in T.Parallel(block_M): |
| 151 | + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) |
| 152 | + for i, j in T.Parallel(block_M, block_N): |
| 153 | + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) |
| 154 | + T.reduce_sum(acc_s, scores_sum, dim=1) |
| 155 | + for i in T.Parallel(block_M): |
| 156 | + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] |
| 157 | + T.copy(acc_s, acc_s_cast) |
| 158 | + |
| 159 | + for i, j in T.Parallel(block_M, dim): |
| 160 | + acc_o[i, j] *= scores_scale[i] |
| 161 | + |
| 162 | + T.copy( |
| 163 | + V_unpad[v_start_idx + k * block_N:v_start_idx + (k + 1) * block_N, |
| 164 | + kv_head_idx, :], V_shared) |
| 165 | + for i, d in T.Parallel(block_N, dim): |
| 166 | + if k * block_N + i >= v_current_seqlen: |
| 167 | + V_shared[i, d] = 0 |
| 168 | + |
| 169 | + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) |
| 170 | + |
| 171 | + for i, j in T.Parallel(block_M, dim): |
| 172 | + acc_o[i, j] /= logsum[i] |
| 173 | + T.copy(acc_o, O_shared) |
| 174 | + |
| 175 | + for i, d in T.Parallel(block_M, dim): |
| 176 | + if bx * block_M + i < q_current_seqlen: |
| 177 | + Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d] |
| 178 | + |
| 179 | + return main |
| 180 | + |
| 181 | + |
| 182 | +def main(batch: int = 1, |
| 183 | + heads: int = 64, |
| 184 | + q_seqlen: int = 2048, |
| 185 | + k_seqlen: int = 2048, |
| 186 | + dim: int = 128, |
| 187 | + groups: int = 16, |
| 188 | + is_causal: bool = False): |
| 189 | + assert heads % groups == 0, "heads must be divisible by groups" |
| 190 | + |
| 191 | + flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim |
| 192 | + total_flops = 2 * flops_per_matmul |
| 193 | + |
| 194 | + tilelang.testing.set_random_seed(0) |
| 195 | + |
| 196 | + causal = False |
| 197 | + if causal: |
| 198 | + total_flops *= 0.5 |
| 199 | + |
| 200 | + tilelang.testing.set_random_seed(0) |
| 201 | + |
| 202 | + dtype = torch.float16 |
| 203 | + device = torch.device("cuda") |
| 204 | + |
| 205 | + head_kv = heads // groups |
| 206 | + q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device, requires_grad=True) |
| 207 | + k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True) |
| 208 | + v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True) |
| 209 | + |
| 210 | + query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random") |
| 211 | + key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random") |
| 212 | + |
| 213 | + ( |
| 214 | + q_unpad, |
| 215 | + k_unpad, |
| 216 | + v_unpad, |
| 217 | + cu_seqlens_q, |
| 218 | + cu_seqlens_k, |
| 219 | + max_seqlen_q, |
| 220 | + max_seqlen_k, |
| 221 | + q, |
| 222 | + k, |
| 223 | + v, |
| 224 | + output_pad_fn, |
| 225 | + _, |
| 226 | + _, |
| 227 | + ) = generate_qkv( |
| 228 | + q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) |
| 229 | + |
| 230 | + UQ = q_unpad.shape[0] |
| 231 | + UKV = k_unpad.shape[0] |
| 232 | + |
| 233 | + kernel = flashattn( |
| 234 | + batch, |
| 235 | + groups, |
| 236 | + UQ, |
| 237 | + UKV, |
| 238 | + heads, |
| 239 | + dim, |
| 240 | + is_causal, |
| 241 | + block_M=64, |
| 242 | + block_N=64, |
| 243 | + num_stages=1, |
| 244 | + threads=128) |
| 245 | + |
| 246 | + out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) |
| 247 | + out = output_pad_fn(out_unpad) |
| 248 | + |
| 249 | + out_ref, _ = attention_ref( |
| 250 | + q, |
| 251 | + k, |
| 252 | + v, |
| 253 | + query_padding_mask=query_padding_mask, |
| 254 | + key_padding_mask=key_padding_mask, |
| 255 | + causal=is_causal, |
| 256 | + ) |
| 257 | + torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2) |
| 258 | + print("All checks passed.✅") |
| 259 | + latency = do_bench( |
| 260 | + lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)) |
| 261 | + print("Tile-lang: {:.2f} ms".format(latency)) |
| 262 | + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) |
| 263 | + |
| 264 | + |
| 265 | +if __name__ == "__main__": |
| 266 | + parser = argparse.ArgumentParser() |
| 267 | + parser.add_argument('--batch', type=int, default=8, help='batch size') |
| 268 | + parser.add_argument('--heads', type=int, default=64, help='query heads') |
| 269 | + parser.add_argument('--groups', type=int, default=16, help='groups') |
| 270 | + parser.add_argument('--q_seqlen', type=int, default=2048, help='query sequence length') |
| 271 | + parser.add_argument('--k_seqlen', type=int, default=2048, help='key/value sequence length') |
| 272 | + parser.add_argument('--dim', type=int, default=128, help='head dim') |
| 273 | + parser.add_argument('--is_causal', action='store_true', help='causal attention') |
| 274 | + args = parser.parse_args() |
| 275 | + main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, |
| 276 | + args.is_causal) |
0 commit comments