Commit 4f3e977 (1 parent: 7904899)
Showing 29 changed files with 1,590 additions and 101 deletions.
@@ -0,0 +1,113 @@
/*
Adapted from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu
*/

#include <torch/extension.h>
#include <cuda_fp16.h>
#include "reduction.cuh"
#include "layernorm.h"
#include <cuda_runtime.h>
#include <c10/cuda/CUDAGuard.h>

static inline __device__ float to_float(half src)
{
    return __half2float(src);
}

static inline __device__ float to_float(float src)
{
    return src;
}

template<typename T>
__global__ void generalT5LayerNorm(
    const T* __restrict input, const T* __restrict gamma, T* output, const float layernorm_eps, int m, int n)
{
    // LayerNorm in the T5 style: no bias and no subtraction of the mean (RMSNorm).
    const int tid = threadIdx.x;

    __shared__ float s_variance;
    float variance = 0.0f;

    float local_var_sum = 0.0f;
    for (int i = tid; i < n; i += blockDim.x) {
        float diff = to_float(__ldg(&input[blockIdx.x * n + i]));
        local_var_sum += diff * diff;
    }
    variance = blockReduceSum(local_var_sum);

    if (threadIdx.x == 0) {
        s_variance = rsqrtf(variance / (float)n + layernorm_eps);
    }
    __syncthreads();

    for (int i = tid; i < n; i += blockDim.x) {
        output[blockIdx.x * n + i] =
            clamp_inf_for_half<T>((to_float(input[blockIdx.x * n + i]) * s_variance) * to_float(__ldg(&gamma[i])));
    }
}

template<typename T>
void invokeGeneralT5LayerNorm(T* out,
                              const T* input,
                              const T* gamma,
                              // const T* beta,
                              const float layernorm_eps,
                              const int m,
                              const int n)
{
    dim3 grid(m);
    dim3 block(min(n, 1024));

    /* For general cases, n is equal to hidden_units, e.g., 512/1024.
       Since the kernel uses warp shuffles, block.x should be a multiple of 32. */
    if (n % 32 != 0) {
        block.x = 1024;
    }

    block.x = block.x / (4 / sizeof(T));  // if using half, only half of block.x is needed

    /* Pay attention to the rsqrt precision. */
    generalT5LayerNorm<T><<<grid, block>>>(input, gamma, out, layernorm_eps, m, n);
}

template void invokeGeneralT5LayerNorm(half* out,
                                       const half* input,
                                       const half* gamma,
                                       // const half* beta,
                                       const float layernorm_eps,
                                       const int m,
                                       const int n);

template void invokeGeneralT5LayerNorm(float* out,
                                       const float* input,
                                       const float* gamma,
                                       // const float* beta,
                                       const float layernorm_eps,
                                       const int m,
                                       const int n);

// input: [b, n, c]
void layernorm_forward_cuda(
    torch::Tensor _input,
    torch::Tensor _gamma,
    torch::Tensor _out,
    float eps)
{
    int m = _input.size(0) * _input.size(1);
    int n = _input.size(2);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(_input));

    auto input = reinterpret_cast<half*>(_input.data_ptr<at::Half>());
    auto gamma = reinterpret_cast<half*>(_gamma.data_ptr<at::Half>());
    auto out = reinterpret_cast<half*>(_out.data_ptr<at::Half>());

    invokeGeneralT5LayerNorm(out, input, gamma, eps, m, n);
}
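For reference, per row x of length n the kernel above computes the T5/RMS-style normalization (no mean subtraction, no bias), with layernorm_eps written as ε:

\[ \mathrm{out}_i = \gamma_i \cdot \frac{x_i}{\sqrt{\tfrac{1}{n}\sum_{j=1}^{n} x_j^2 + \varepsilon}} \]

For half-precision outputs the result is additionally clamped away from ±inf by clamp_inf_for_half (defined in reduction.cuh below).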
@@ -0,0 +1,3 @@
#include <torch/extension.h>

void layernorm_forward_cuda(torch::Tensor _input, torch::Tensor _gamma, torch::Tensor _out, float eps);
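A minimal host-side sketch of calling this declaration (an illustration, not part of this commit; the shapes, the eps value, and the helper name example_layernorm_call are assumptions). layernorm_forward_cuda expects fp16 CUDA tensors shaped [b, n, c], matching the at::Half casts in the definition above:

#include <torch/extension.h>
#include "layernorm.h"

// Hypothetical helper: allocate fp16 CUDA tensors and run the layernorm kernel once.
void example_layernorm_call() {
    auto opts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
    torch::Tensor input = torch::randn({2, 8, 512}, opts);   // [b, n, c]
    torch::Tensor gamma = torch::ones({512}, opts);          // per-channel scale
    torch::Tensor out   = torch::empty_like(input);
    layernorm_forward_cuda(input, gamma, out, 1e-6f);        // eps value is an assumption
}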
@@ -0,0 +1,82 @@
/*
Adapted from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/reduce_kernel_utils.cuh
*/

#pragma once
#include <assert.h>
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
#include <cooperative_groups/reduce.h>
#else
#include <cooperative_groups.h>
#endif
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <float.h>
#include <type_traits>

static const float HALF_FLT_MAX = 65504.F;
#define FINAL_MASK 0xffffffff

template<typename T>
inline __device__ T add(T a, T b) {
    return a + b;
}

template<>
inline __device__ half2 add(half2 a, half2 b) {
    return __hadd2(a, b);
}

template<>
inline __device__ half add(half a, half b) {
    return __hadd(a, b);
}

template<typename T>
__inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1)
        val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));  // __shfl_sync on bf16 returns float when sm < 80
    return val;
}

/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val)
{
    static __shared__ T shared[32];
    int lane = threadIdx.x & 0x1f;
    int wid = threadIdx.x >> 5;

    val = warpReduceSum<T>(val);

    if (lane == 0)
        shared[wid] = val;

    __syncthreads();

    // Use blockDim.x / 32.f (rather than blockDim.x >> 5) so the active-warp
    // count is handled correctly when blockDim.x is not a multiple of 32.
    val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
    val = warpReduceSum<T>(val);

    return val;
}

template<typename T>
__device__ __forceinline__ T clamp_inf_for_half(const float input)
{
    return input;
}

template<>
__device__ __forceinline__ half clamp_inf_for_half(const float input)
{
    // clamp inf values to enable fp16 training
    return input > 0.0f ? __float2half(min(input, HALF_FLT_MAX - 1000)) : __float2half(max(input, -HALF_FLT_MAX + 1000));
}
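As an illustration of how these helpers are meant to be used (not part of this commit; the kernel name row_sum_kernel, the 256-thread block size, and the row-major [m, n] layout are assumptions), a row-sum kernel follows the same reduction pattern as generalT5LayerNorm above:

#include <cuda_runtime.h>
#include "reduction.cuh"

// One block per row: each thread accumulates a strided partial sum, then
// blockReduceSum combines the partials via warp shuffles plus shared memory.
__global__ void row_sum_kernel(const float* __restrict__ in, float* __restrict__ out, int n) {
    float local = 0.0f;
    for (int i = threadIdx.x; i < n; i += blockDim.x)
        local += in[blockIdx.x * n + i];
    float total = blockReduceSum(local);
    if (threadIdx.x == 0)
        out[blockIdx.x] = total;   // one sum per row
}

// Launch sketch: one block per row, e.g. row_sum_kernel<<<m, 256>>>(d_in, d_out, n);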
@@ -0,0 +1,9 @@
#pragma once
#include <torch/extension.h>

void rotary_embedding_neox(
    torch::Tensor& positions,
    torch::Tensor& query,
    torch::Tensor& key,
    int head_size,
    torch::Tensor& cos_sin_cache);
awq/kernels/csrc/position_embedding/pos_encoding_kernels.cu (88 additions, 0 deletions)
@@ -0,0 +1,88 @@
/*
Adapted from the vLLM project:
https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
*/

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "pos_encoding.h"

template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
    const int64_t* __restrict__ positions,       // [num_tokens]
    scalar_t* __restrict__ query,                // [num_tokens, num_heads, head_size]
    scalar_t* __restrict__ key,                  // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
    const int rot_dim,
    const int stride,
    const int num_heads,
    const int head_size) {
    // Each thread block is responsible for one token.
    const int token_idx = blockIdx.x;
    int64_t pos = positions[token_idx];
    const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

    const int embed_dim = rot_dim / 2;
    const int n = num_heads * embed_dim;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        const int head_idx = i / embed_dim;
        const int token_head = token_idx * stride + head_idx * head_size;

        const int rot_offset = i % embed_dim;
        const int x_index = rot_offset;
        const int y_index = embed_dim + rot_offset;

        const int out_x = token_idx * stride + head_idx * head_size + x_index;
        const int out_y = token_idx * stride + head_idx * head_size + y_index;

        const scalar_t cos = __ldg(cache_ptr + x_index);
        const scalar_t sin = __ldg(cache_ptr + y_index);

        const scalar_t q_x = query[token_head + x_index];
        const scalar_t q_y = query[token_head + y_index];
        query[out_x] = q_x * cos - q_y * sin;
        query[out_y] = q_y * cos + q_x * sin;

        const scalar_t k_x = key[token_head + x_index];
        const scalar_t k_y = key[token_head + y_index];
        key[out_x] = k_x * cos - k_y * sin;
        key[out_y] = k_y * cos + k_x * sin;
    }
}

void rotary_embedding_neox(
    torch::Tensor& positions,      // [b, num_tokens]
    torch::Tensor& query,          // [b, num_tokens, 1, num_heads, head_size]
    torch::Tensor& key,            // [b, num_tokens, 1, num_heads, head_size]
    int head_size,
    torch::Tensor& cos_sin_cache)  // [max_position, rot_dim]
{
    int num_tokens = query.size(0) * query.size(1);
    int rot_dim = cos_sin_cache.size(1);
    int num_heads = query.size(-2);
    int stride = num_heads * head_size;
    // TORCH_CHECK(stride == key.stride(0));

    dim3 grid(num_tokens);
    dim3 block(std::min(num_heads * rot_dim / 2, 512));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        query.scalar_type(),
        "rotary_embedding_neox",
        [&] {
            rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
                positions.data_ptr<int64_t>(),
                query.data_ptr<scalar_t>(),
                key.data_ptr<scalar_t>(),
                cos_sin_cache.data_ptr<scalar_t>(),
                rot_dim,
                stride,
                num_heads,
                head_size);
        });
}
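The kernel above applies a standard 2D rotation to each (x, y) pair taken from positions i and i + rot_dim/2 of a head, reading cos values from the first half and sin values from the second half of the cos_sin_cache row for the token's position. The commit itself does not show how that cache is built; a hedged host-side sketch follows (the function name build_cos_sin_cache, the base of 10000, and the fp32 dtype are assumptions about the NeoX frequency schedule):

#include <cmath>
#include <torch/extension.h>

// Hypothetical helper: [max_position, rot_dim] cache with cos in columns
// [0, rot_dim/2) and sin in columns [rot_dim/2, rot_dim), matching the
// indexing used by rotary_embedding_neox_kernel above.
torch::Tensor build_cos_sin_cache(int max_position, int rot_dim, double base = 10000.0) {
    torch::Tensor cache = torch::empty({max_position, rot_dim}, torch::dtype(torch::kFloat32));
    auto acc = cache.accessor<float, 2>();
    const int embed_dim = rot_dim / 2;
    for (int pos = 0; pos < max_position; ++pos) {
        for (int i = 0; i < embed_dim; ++i) {
            double inv_freq = std::pow(base, -2.0 * i / rot_dim);  // assumed NeoX frequency schedule
            double angle = pos * inv_freq;
            acc[pos][i] = static_cast<float>(std::cos(angle));
            acc[pos][embed_dim + i] = static_cast<float>(std::sin(angle));
        }
    }
    return cache;  // cast to the query/key dtype (e.g. .to(torch::kHalf)) before calling rotary_embedding_neox
}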
@@ -0,0 +1,12 @@
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "layernorm/layernorm.h"
#include "quantization/gemm_cuda.h"
#include "position_embedding/pos_encoding.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
    m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
    m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
}
File renamed without changes.
File renamed without changes.
This file was deleted.