|
| 1 | +import torch |
| 2 | +import triton |
| 3 | +import triton.language as tl |
| 4 | + |
| 5 | + |
| 6 | +@triton.jit |
| 7 | +def rotary_embedding_kernel( |
| 8 | + q, |
| 9 | + k, |
| 10 | + cos, |
| 11 | + sin, |
| 12 | + q_token_stride, |
| 13 | + q_head_stride, |
| 14 | + k_token_stride, |
| 15 | + k_head_stride, |
| 16 | + head_dim_stride, |
| 17 | + cos_token_stride, |
| 18 | + cos_stride, |
| 19 | + q_total_tokens, |
| 20 | + Q_HEAD_NUM: tl.constexpr, |
| 21 | + K_HEAD_NUM: tl.constexpr, |
| 22 | + HEAD_DIM: tl.constexpr, |
| 23 | + BLOCK_HEAD: tl.constexpr, |
| 24 | + BLOCK_TOKENS: tl.constexpr, |
| 25 | +): |
| 26 | + block_head_index = tl.program_id(0) |
| 27 | + block_token_index = tl.program_id(1) |
| 28 | + |
| 29 | + rotary_data = q |
| 30 | + HEAD_NUM = Q_HEAD_NUM |
| 31 | + head_stride = q_head_stride |
| 32 | + token_stride = q_token_stride |
| 33 | + |
| 34 | + if block_token_index * BLOCK_TOKENS >= q_total_tokens: |
| 35 | + block_token_index = block_token_index - tl.cdiv(q_total_tokens, BLOCK_TOKENS) |
| 36 | + rotary_data = k |
| 37 | + HEAD_NUM = K_HEAD_NUM |
| 38 | + head_stride = k_head_stride |
| 39 | + token_stride = k_token_stride |
| 40 | + |
| 41 | + tokens_range = block_token_index * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS) |
| 42 | + head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) |
| 43 | + |
| 44 | + dim_range0 = tl.arange(0, HEAD_DIM // 2) |
| 45 | + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) |
| 46 | + |
| 47 | + off_data0 = ( |
| 48 | + tokens_range[:, None, None] * token_stride |
| 49 | + + head_range[None, :, None] * head_stride |
| 50 | + + dim_range0[None, None, :] * head_dim_stride |
| 51 | + ) |
| 52 | + off_data1 = ( |
| 53 | + tokens_range[:, None, None] * token_stride |
| 54 | + + head_range[None, :, None] * head_stride |
| 55 | + + dim_range1[None, None, :] * head_dim_stride |
| 56 | + ) |
| 57 | + |
| 58 | + loaded_data0 = tl.load( |
| 59 | + rotary_data + off_data0, |
| 60 | + mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), |
| 61 | + other=0.0, |
| 62 | + ) |
| 63 | + loaded_data1 = tl.load( |
| 64 | + rotary_data + off_data1, |
| 65 | + mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), |
| 66 | + other=0.0, |
| 67 | + ) |
| 68 | + |
| 69 | + off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride |
| 70 | + |
| 71 | + loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0) |
| 72 | + loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0) |
| 73 | + |
| 74 | + out0 = loaded_data0 * loaded_cos[:, None, :] - loaded_data1 * loaded_sin[:, None, :] |
| 75 | + out1 = loaded_data0 * loaded_sin[:, None, :] + loaded_data1 * loaded_cos[:, None, :] |
| 76 | + |
| 77 | + # concat |
| 78 | + tl.store( |
| 79 | + rotary_data + off_data0, |
| 80 | + out0, |
| 81 | + mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), |
| 82 | + ) |
| 83 | + tl.store( |
| 84 | + rotary_data + off_data1, |
| 85 | + out1, |
| 86 | + mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)), |
| 87 | + ) |
| 88 | + |
| 89 | + |
| 90 | +@torch.no_grad() |
| 91 | +def rotary_embedding( |
| 92 | + q: torch.Tensor, |
| 93 | + k: torch.Tensor, |
| 94 | + cos: torch.Tensor, |
| 95 | + sin: torch.Tensor, |
| 96 | +): |
| 97 | + """ |
| 98 | + Args: |
| 99 | + q: query tensor, [total_tokens, head_num, head_dim] |
| 100 | + k: key tensor, [total_tokens, head_num, head_dim] |
| 101 | + cos: cosine for rotary embedding, [total_tokens, head_dim] |
| 102 | + sin: sine for rotary embedding, [total_tokens, head_dim] |
| 103 | + """ |
| 104 | + q_total_tokens, q_head_num, head_dim = q.shape |
| 105 | + assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" |
| 106 | + BLOCK_HEAD = 4 |
| 107 | + BLOCK_TOKENS = 8 |
| 108 | + grid = (triton.cdiv(q_head_num, BLOCK_HEAD), 2 * triton.cdiv(q_total_tokens, BLOCK_TOKENS)) |
| 109 | + |
| 110 | + if head_dim >= 128: |
| 111 | + num_warps = 8 |
| 112 | + else: |
| 113 | + num_warps = 4 |
| 114 | + |
| 115 | + q_token_stride = q.stride(0) |
| 116 | + q_head_stride = q.stride(1) |
| 117 | + head_dim_stride = q.stride(2) |
| 118 | + |
| 119 | + k_token_stride = k.stride(0) |
| 120 | + k_head_stride = k.stride(1) |
| 121 | + |
| 122 | + k_head_num = q.shape[1] |
| 123 | + |
| 124 | + cos_token_stride = cos.stride(0) |
| 125 | + cos_stride = cos.stride(1) |
| 126 | + |
| 127 | + rotary_embedding_kernel[grid]( |
| 128 | + q, |
| 129 | + k, |
| 130 | + cos, |
| 131 | + sin, |
| 132 | + q_token_stride, |
| 133 | + q_head_stride, |
| 134 | + k_token_stride, |
| 135 | + k_head_stride, |
| 136 | + head_dim_stride, |
| 137 | + cos_token_stride, |
| 138 | + cos_stride, |
| 139 | + q_total_tokens, |
| 140 | + Q_HEAD_NUM=q_head_num, |
| 141 | + K_HEAD_NUM=k_head_num, |
| 142 | + HEAD_DIM=head_dim, |
| 143 | + BLOCK_HEAD=BLOCK_HEAD, |
| 144 | + BLOCK_TOKENS=BLOCK_TOKENS, |
| 145 | + num_warps=num_warps, |
| 146 | + num_stages=1, |
| 147 | + ) |
| 148 | + |
| 149 | + return |
0 commit comments