|
| 1 | +import torch |
| 2 | +import triton |
| 3 | +import triton.language as tl |
| 4 | + |
| 5 | + |
| 6 | +@triton.jit |
| 7 | +def fused_rotary_emb( |
| 8 | + q, |
| 9 | + k, |
| 10 | + cos_cache, |
| 11 | + sin_cache, |
| 12 | + cumsum_lengths, |
| 13 | + q_token_stride, |
| 14 | + q_head_stride, |
| 15 | + k_token_stride, |
| 16 | + k_head_stride, |
| 17 | + head_dim_stride, |
| 18 | + cos_token_stride, |
| 19 | + cos_dim_stride, |
| 20 | + q_total_tokens, |
| 21 | + Q_HEAD_NUM: tl.constexpr, |
| 22 | + K_HEAD_NUM: tl.constexpr, |
| 23 | + HEAD_DIM: tl.constexpr, |
| 24 | + BLOCK_HEAD: tl.constexpr, |
| 25 | + BLOCK_SIZE: tl.constexpr, |
| 26 | + N_ELEMENTS: tl.constexpr, |
| 27 | +): |
| 28 | + block_head_index = tl.program_id(0) |
| 29 | + block_group_index = tl.program_id(1) |
| 30 | + group_token_index = tl.program_id(2) |
| 31 | + idx = block_group_index * BLOCK_SIZE + group_token_index |
| 32 | + |
| 33 | + # original seq_idx and pos |
| 34 | + cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS)) |
| 35 | + ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0)) |
| 36 | + cos = tl.load( |
| 37 | + cos_cache + ori_seq_idx * cos_token_stride + tl.arange(0, HEAD_DIM // 2) * cos_dim_stride |
| 38 | + ) # [1,HEAD_DIM//2] |
| 39 | + sin = tl.load(sin_cache + ori_seq_idx * cos_token_stride + tl.arange(0, HEAD_DIM // 2) * cos_dim_stride) |
| 40 | + |
| 41 | + cur_head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) |
| 42 | + dim_range0 = tl.arange(0, HEAD_DIM // 2) |
| 43 | + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) |
| 44 | + |
| 45 | + off_q0 = ( |
| 46 | + idx * q_token_stride |
| 47 | + + cur_head_range[None, :, None] * q_head_stride |
| 48 | + + dim_range0[None, None, :] * head_dim_stride |
| 49 | + ) |
| 50 | + off_q1 = ( |
| 51 | + idx * q_token_stride |
| 52 | + + cur_head_range[None, :, None] * q_head_stride |
| 53 | + + dim_range1[None, None, :] * head_dim_stride |
| 54 | + ) |
| 55 | + |
| 56 | + off_k0 = ( |
| 57 | + idx * k_token_stride |
| 58 | + + cur_head_range[None, :, None] * k_head_stride |
| 59 | + + dim_range0[None, None, :] * head_dim_stride |
| 60 | + ) |
| 61 | + off_k1 = ( |
| 62 | + idx * q_token_stride |
| 63 | + + cur_head_range[None, :, None] * k_head_stride |
| 64 | + + dim_range1[None, None, :] * head_dim_stride |
| 65 | + ) |
| 66 | + |
| 67 | + q_0 = tl.load( |
| 68 | + q + off_q0, |
| 69 | + mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)), |
| 70 | + other=0.0, |
| 71 | + ) |
| 72 | + |
| 73 | + q_1 = tl.load( |
| 74 | + q + off_q1, |
| 75 | + mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)), |
| 76 | + other=0.0, |
| 77 | + ) |
| 78 | + |
| 79 | + k_0 = tl.load( |
| 80 | + k + off_k0, |
| 81 | + mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)), |
| 82 | + other=0.0, |
| 83 | + ) |
| 84 | + |
| 85 | + k_1 = tl.load( |
| 86 | + k + off_k1, |
| 87 | + mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)), |
| 88 | + other=0.0, |
| 89 | + ) |
| 90 | + |
| 91 | + out_q0 = q_0 * cos - q_1 * sin |
| 92 | + out_q1 = k_0 * sin + k_1 * cos |
| 93 | + |
| 94 | + out_k0 = q_0 * cos - q_1 * sin |
| 95 | + out_k1 = k_0 * sin + k_1 * cos |
| 96 | + # concat |
| 97 | + tl.store( |
| 98 | + q + off_q0, |
| 99 | + out_q0, |
| 100 | + mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)), |
| 101 | + ) |
| 102 | + tl.store( |
| 103 | + q + off_q1, |
| 104 | + out_q1, |
| 105 | + mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)), |
| 106 | + ) |
| 107 | + |
| 108 | + tl.store( |
| 109 | + k + off_k0, |
| 110 | + out_k0, |
| 111 | + mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)), |
| 112 | + ) |
| 113 | + tl.store( |
| 114 | + k + off_k1, |
| 115 | + out_k1, |
| 116 | + mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)), |
| 117 | + ) |
| 118 | + |
| 119 | + |
| 120 | +@torch.no_grad() |
| 121 | +def fused_rotary_embedding( |
| 122 | + q: torch.Tensor, |
| 123 | + k: torch.Tensor, |
| 124 | + cos: torch.Tensor, |
| 125 | + sin: torch.Tensor, |
| 126 | + lengths, |
| 127 | +): |
| 128 | + """ |
| 129 | + Args: |
| 130 | + q: query tensor, [total_tokens, head_num, head_dim] |
| 131 | + k: key tensor, [total_tokens, head_num, head_dim] |
| 132 | + cos: cosine for rotary embedding, [max_position_len, head_dim] |
| 133 | + sin: sine for rotary embedding, [max_position_len, head_dim] |
| 134 | + lengths [num_seqs] |
| 135 | + """ |
| 136 | + q_total_tokens, q_head_num, head_dim = q.shape |
| 137 | + assert q.size(0) == k.size(0) |
| 138 | + BLOCK_HEAD = 4 |
| 139 | + BLOCK_SIZE = 16 |
| 140 | + cumsum_lens = torch.cumsum(lengths, dim=0) |
| 141 | + |
| 142 | + grid = (triton.cdiv(q_head_num, BLOCK_HEAD), triton.cdiv(q_total_tokens, BLOCK_SIZE), BLOCK_SIZE) |
| 143 | + |
| 144 | + if head_dim >= 128: |
| 145 | + num_warps = 8 |
| 146 | + else: |
| 147 | + num_warps = 4 |
| 148 | + |
| 149 | + q_token_stride = q.stride(0) |
| 150 | + q_head_stride = q.stride(1) |
| 151 | + head_dim_stride = q.stride(2) |
| 152 | + |
| 153 | + k_token_stride = k.stride(0) |
| 154 | + k_head_stride = k.stride(1) |
| 155 | + |
| 156 | + k_head_num = q.shape[1] |
| 157 | + |
| 158 | + cos_token_stride = cos.stride(0) |
| 159 | + cos_dim_stride = cos.stride(1) |
| 160 | + |
| 161 | + fused_rotary_emb[grid]( |
| 162 | + q, |
| 163 | + k, |
| 164 | + cos, |
| 165 | + sin, |
| 166 | + cumsum_lens, |
| 167 | + q_token_stride, |
| 168 | + q_head_stride, |
| 169 | + k_token_stride, |
| 170 | + k_head_stride, |
| 171 | + head_dim_stride, |
| 172 | + cos_token_stride, |
| 173 | + cos_dim_stride, |
| 174 | + q_total_tokens, |
| 175 | + Q_HEAD_NUM=q_head_num, |
| 176 | + K_HEAD_NUM=k_head_num, |
| 177 | + HEAD_DIM=head_dim, |
| 178 | + BLOCK_HEAD=BLOCK_HEAD, |
| 179 | + BLOCK_SIZE=BLOCK_SIZE, |
| 180 | + N_ELEMENTS=triton.next_power_of_2(q_total_tokens), |
| 181 | + num_warps=num_warps, |
| 182 | + ) |
0 commit comments