|
| 1 | +#include <ATen/cuda/CUDAContext.h> |
| 2 | +#include <torch/extension.h> |
| 3 | + |
| 4 | +#include "utils/vector_copy_utils.h" |
| 5 | +#include "../common/micros.h" |
| 6 | +#include "stdio.h" |
| 7 | + |
| 8 | +template <typename scalar_t, bool Aligned, int VecSize> |
| 9 | +__device__ void apply_cos_and_sin_memcopy( |
| 10 | + scalar_t* __restrict__ cos, |
| 11 | + scalar_t* __restrict__ sin, |
| 12 | + const scalar_t* __restrict__ cos_cache_ptr, |
| 13 | + const scalar_t* __restrict__ sin_cache_ptr, |
| 14 | + const int* __restrict__ sequence_lengths, |
| 15 | + const int head_dim, |
| 16 | + const int dest_offset_id, |
| 17 | + const int src_offset_id |
| 18 | + ) { |
| 19 | + |
| 20 | + int begin_id = threadIdx.x * VecSize; |
| 21 | + |
| 22 | + for (; begin_id <= head_dim - VecSize; begin_id += blockDim.x){ |
| 23 | + copy_vector<scalar_t, VecSize>(cos + dest_offset_id + begin_id, cos_cache_ptr + src_offset_id + begin_id); |
| 24 | + copy_vector<scalar_t, VecSize>(sin + dest_offset_id + begin_id, sin_cache_ptr + src_offset_id + begin_id); |
| 25 | + } |
| 26 | + |
| 27 | + if (!Aligned) { |
| 28 | + for (; begin_id < head_dim; ++begin_id ) { |
| 29 | + cos[dest_offset_id + begin_id] = cos_cache_ptr[src_offset_id + begin_id]; |
| 30 | + sin[dest_offset_id + begin_id] = sin_cache_ptr[src_offset_id + begin_id]; |
| 31 | + } |
| 32 | + } |
| 33 | +} |
| 34 | + |
| 35 | +template <typename scalar_t, bool Aligned, int VecSize> |
| 36 | +__global__ void apply_get_context_cos_and_sin_kernel( |
| 37 | + scalar_t* __restrict__ cos, |
| 38 | + scalar_t* __restrict__ sin, |
| 39 | + const scalar_t* __restrict__ cos_cache_ptr, |
| 40 | + const scalar_t* __restrict__ sin_cache_ptr, |
| 41 | + const int* __restrict__ sequence_lengths, |
| 42 | + const int* __restrict__ cumsum_lengths, |
| 43 | + const int batch_size, |
| 44 | + const int head_dim |
| 45 | +) { |
| 46 | + int token_id = blockIdx.x; |
| 47 | + if ( token_id >= sequence_lengths[blockIdx.y] ) { |
| 48 | + return ; |
| 49 | + } |
| 50 | + |
| 51 | + int src_offset_id = token_id * head_dim; |
| 52 | + int dest_offset_id = src_offset_id; |
| 53 | + |
| 54 | + if (blockIdx.y > 0) { |
| 55 | + dest_offset_id += cumsum_lengths[blockIdx.y - 1] * head_dim; |
| 56 | + } |
| 57 | + |
| 58 | + apply_cos_and_sin_memcopy<scalar_t, Aligned, VecSize>( |
| 59 | + cos, |
| 60 | + sin, |
| 61 | + cos_cache_ptr, |
| 62 | + sin_cache_ptr, |
| 63 | + sequence_lengths, |
| 64 | + head_dim, |
| 65 | + dest_offset_id, |
| 66 | + src_offset_id |
| 67 | + ); |
| 68 | + |
| 69 | +} |
| 70 | + |
| 71 | +template <typename scalar_t, bool Aligned, int VecSize> |
| 72 | +__global__ void apply_get_decode_cos_and_sin_kernel( |
| 73 | + scalar_t* __restrict__ cos, |
| 74 | + scalar_t* __restrict__ sin, |
| 75 | + const scalar_t* __restrict__ cos_cache_ptr, |
| 76 | + const scalar_t* __restrict__ sin_cache_ptr, |
| 77 | + const int* __restrict__ sequence_lengths, |
| 78 | + const int batch_size, |
| 79 | + const int head_dim |
| 80 | +) { |
| 81 | + int src_offset_id = ( sequence_lengths[blockIdx.y] - 1 ) * head_dim; |
| 82 | + int dest_offset_id = blockIdx.y * head_dim; |
| 83 | + |
| 84 | + apply_cos_and_sin_memcopy<scalar_t, Aligned, VecSize>( |
| 85 | + cos, |
| 86 | + sin, |
| 87 | + cos_cache_ptr, |
| 88 | + sin_cache_ptr, |
| 89 | + sequence_lengths, |
| 90 | + head_dim, |
| 91 | + dest_offset_id, |
| 92 | + src_offset_id |
| 93 | + ); |
| 94 | +} |
| 95 | + |
| 96 | +template<typename scalar_t> |
| 97 | +void apply_get_cos_and_sin( |
| 98 | + at::Tensor& cos_cache, // [max_rotary_position, head_dim] |
| 99 | + at::Tensor& sin_cache, // [max_rotary_position, head_dim] |
| 100 | + at::Tensor& cos, // [num_tokens, head_dim] |
| 101 | + at::Tensor& sin, // [num_tokens, head_dim] |
| 102 | + at::Tensor& sequence_lengths, // [batch_size] |
| 103 | + int max_seq_len_in_batch, |
| 104 | + bool is_prompts |
| 105 | +) { |
| 106 | + int token_num = cos.size(0); |
| 107 | + int head_dim = cos.size(1); |
| 108 | + int batch_size = sequence_lengths.size(0); |
| 109 | + |
| 110 | + at::Tensor cumsum_lengths; |
| 111 | + |
| 112 | + int vec_size = get_vec_size<scalar_t>(cos); |
| 113 | + |
| 114 | + bool aligned = true; |
| 115 | + if (head_dim % vec_size != 0) { |
| 116 | + aligned = false; |
| 117 | + } |
| 118 | + |
| 119 | + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| 120 | + int block_size_y; |
| 121 | + int block_size_x; |
| 122 | + |
| 123 | + if (is_prompts) { |
| 124 | + block_size_y = batch_size; |
| 125 | + block_size_x = max_seq_len_in_batch; |
| 126 | + // TODO: The cumsum operation can be fused into get_cos_and_sin kernel later on. |
| 127 | + cumsum_lengths = torch::cumsum(sequence_lengths, 0, torch::kInt32); |
| 128 | + } |
| 129 | + else{ |
| 130 | + block_size_y = batch_size; |
| 131 | + block_size_x = 1; |
| 132 | + } |
| 133 | + |
| 134 | + int thread_nums = (head_dim + vec_size - 1) / vec_size; |
| 135 | + |
| 136 | + dim3 grid(block_size_x, block_size_y); |
| 137 | + dim3 block(std::min(thread_nums, 512)); |
| 138 | + |
| 139 | +#define GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, __vec_size) \ |
| 140 | + do { \ |
| 141 | + if (is_prompts){ \ |
| 142 | + apply_get_context_cos_and_sin_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \ |
| 143 | + cos.data_ptr<scalar_t>(), \ |
| 144 | + sin.data_ptr<scalar_t>(), \ |
| 145 | + cos_cache.data_ptr<scalar_t>(), \ |
| 146 | + sin_cache.data_ptr<scalar_t>(), \ |
| 147 | + sequence_lengths.data_ptr<int>(), \ |
| 148 | + cumsum_lengths.data_ptr<int>(), \ |
| 149 | + batch_size, \ |
| 150 | + head_dim \ |
| 151 | + ); \ |
| 152 | + } \ |
| 153 | + else { \ |
| 154 | + apply_get_decode_cos_and_sin_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \ |
| 155 | + cos.data_ptr<scalar_t>(), \ |
| 156 | + sin.data_ptr<scalar_t>(), \ |
| 157 | + cos_cache.data_ptr<scalar_t>(), \ |
| 158 | + sin_cache.data_ptr<scalar_t>(), \ |
| 159 | + sequence_lengths.data_ptr<int>(), \ |
| 160 | + batch_size, \ |
| 161 | + head_dim \ |
| 162 | + ); \ |
| 163 | + } \ |
| 164 | + } while(0) |
| 165 | + |
| 166 | +#define GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned) \ |
| 167 | + do { \ |
| 168 | + switch (vec_size) { \ |
| 169 | + case 1: \ |
| 170 | + GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 1); \ |
| 171 | + break; \ |
| 172 | + case 2: \ |
| 173 | + GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 2); \ |
| 174 | + break; \ |
| 175 | + case 4: \ |
| 176 | + GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 4); \ |
| 177 | + break; \ |
| 178 | + default: \ |
| 179 | + AT_ERROR("Unsupported vectorized size ", vec_size); \ |
| 180 | + break; \ |
| 181 | + } \ |
| 182 | + } while(0) |
| 183 | + |
| 184 | + if (aligned) { |
| 185 | + GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(true); |
| 186 | + } |
| 187 | + else { |
| 188 | + GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(false); |
| 189 | + } |
| 190 | + |
| 191 | + AT_CUDA_CHECK(cudaGetLastError()); |
| 192 | +} |
| 193 | + |
| 194 | +void get_cos_and_sin( |
| 195 | + at::Tensor& cos_cache, // [max_rotary_position, head_dim] |
| 196 | + at::Tensor& sin_cache, // [max_rotary_position, head_dim] |
| 197 | + at::Tensor& cos, // [num_tokens, head_dim] |
| 198 | + at::Tensor& sin, // [num_tokens, head_dim] |
| 199 | + at::Tensor& sequence_lengths, // [batch_size] |
| 200 | + int max_seq_len_in_batch, |
| 201 | + bool is_prompts |
| 202 | +) { |
| 203 | + DISPATCH_FLOAT_HALF_AND_BFLOAT( |
| 204 | + cos.scalar_type(), |
| 205 | + "get_cos_and_sin", |
| 206 | + apply_get_cos_and_sin<scalar_t>( |
| 207 | + cos_cache, |
| 208 | + sin_cache, |
| 209 | + cos, |
| 210 | + sin, |
| 211 | + sequence_lengths, |
| 212 | + max_seq_len_in_batch, |
| 213 | + is_prompts |
| 214 | + );) |
| 215 | +} |
0 commit comments