
Commit

[Major] Add TinyChat and demo.
kentang-mit committed Jul 24, 2023
1 parent 7904899 commit 4f3e977
Showing 29 changed files with 1,590 additions and 101 deletions.
8 changes: 7 additions & 1 deletion README.md
@@ -12,7 +12,13 @@ The current release supports:
- Efficient CUDA kernel implementation for fast inference (supports both the context and decoding stages).
- Examples on 4-bit inference of an instruction-tuned model (Vicuna) and multi-modal LM (LLaVA).

![TinyChat on RTX 4090: W4A16 is 2.3x faster than FP16](./tinychat/figures/4090_example.gif)

Check out [TinyChat](tinychat), which delivers 2.3x faster inference performance for the **LLaMA-2** chatbot!


## News
- [2023/07] 🔥 We released TinyChat, an efficient and minimal chatbot interface based on AWQ. Llama-2-chat models are supported! Check out our implementation [here](tinychat).
- [2023/07] 🔥 We added AWQ support and pre-computed search results for Llama-2 models (7B & 13B). Check out our model zoo [here](https://huggingface.co/datasets/mit-han-lab/awq-model-zoo)!
- [2023/07] We extended the support for more LLM models including MPT, Falcon, and BLOOM.

@@ -40,7 +46,7 @@ pip install --upgrade pip  # enable PEP 660 support
pip install -e .
```

3. Install efficient W4A16 (4-bit weight, 16-bit activation) CUDA kernel
3. Install efficient W4A16 (4-bit weight, 16-bit activation) CUDA kernel and optimized FP16 kernels (e.g. layernorm, positional encodings).
```
cd awq/kernels
python setup.py install
```
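
Once the build finishes, a quick sanity check (a sketch; the module name `awq_inference_engine` and the three bound ops come from `csrc/pybind.cpp` in this commit) is to import the extension and confirm the kernels are exposed:

```
# Sanity check after `python setup.py install` (sketch; names come from csrc/pybind.cpp).
import awq_inference_engine as engine

print(engine.gemm_forward_cuda)        # quantized W4A16 GEMM
print(engine.layernorm_forward_cuda)   # FasterTransformer-style layernorm
print(engine.rotary_embedding_neox)    # GPT-NeoX rotary embedding
```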
113 changes: 113 additions & 0 deletions awq/kernels/csrc/layernorm/layernorm.cu
@@ -0,0 +1,113 @@
/*
Adapted from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu
*/

#include <torch/extension.h>
#include <cuda_fp16.h>
#include "reduction.cuh"
#include "layernorm.h"
#include <cuda_runtime.h>
#include <c10/cuda/CUDAGuard.h>

static inline __device__ float to_float(half src)
{
    return __half2float(src);
}

static inline __device__ float to_float(float src)
{
    return src;
}

template<typename T>
__global__ void generalT5LayerNorm(
    const T* __restrict input, const T* __restrict gamma, T* output, const float layernorm_eps, int m, int n)
{
    // LayerNorm in the T5 style: no bias and no subtraction of the mean (RMSNorm).
    const int tid = threadIdx.x;

    __shared__ float s_variance;
    float variance = 0.0f;

    float local_var_sum = 0.0f;
    for (int i = tid; i < n; i += blockDim.x) {
        float diff = to_float(__ldg(&input[blockIdx.x * n + i]));
        local_var_sum += diff * diff;
    }
    variance = blockReduceSum(local_var_sum);

    if (threadIdx.x == 0) {
        s_variance = rsqrtf(variance / (float)n + layernorm_eps);
    }
    __syncthreads();

    for (int i = tid; i < n; i += blockDim.x) {
        output[blockIdx.x * n + i] =
            clamp_inf_for_half<T>((to_float(input[blockIdx.x * n + i]) * s_variance) * to_float(__ldg(&gamma[i])));
    }
}


template<typename T>
void invokeGeneralT5LayerNorm(T* out,
                              const T* input,
                              const T* gamma,
                              // const T* beta,
                              const float layernorm_eps,
                              const int m,
                              const int n)
{
    dim3 grid(m);
    dim3 block(min(n, 1024));

    /* For general cases, n is equal to hidden_units, e.g., 512/1024.
       Since we have warp shuffle inside the code, block.x % 32 should be 0.
    */
    if (n % 32 != 0) {
        block.x = 1024;
    }

    block.x = block.x / (4 / sizeof(T));  // if using half, only need half of block.x

    /* should pay attention to the rsqrt precision */
    generalT5LayerNorm<T><<<grid, block>>>(input, gamma, out, layernorm_eps, m, n);  // For gpt-3
}

template void invokeGeneralT5LayerNorm(half* out,
                                       const half* input,
                                       const half* gamma,
                                       // const half* beta,
                                       const float layernorm_eps,
                                       const int m,
                                       const int n);

template void invokeGeneralT5LayerNorm(float* out,
                                       const float* input,
                                       const float* gamma,
                                       // const half* beta,
                                       const float layernorm_eps,
                                       const int m,
                                       const int n);



// input: [b, n, c]
void layernorm_forward_cuda(
    torch::Tensor _input,
    torch::Tensor _gamma,
    torch::Tensor _out,
    float eps)
{
    int m = _input.size(0) * _input.size(1);
    int n = _input.size(2);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(_input));

    auto input = reinterpret_cast<half*>(_input.data_ptr<at::Half>());
    auto gamma = reinterpret_cast<half*>(_gamma.data_ptr<at::Half>());
    auto out = reinterpret_cast<half*>(_out.data_ptr<at::Half>());

    invokeGeneralT5LayerNorm(out, input, gamma, eps, m, n);
}
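
For reference, the kernel above implements the T5-style layer norm (RMSNorm: no mean subtraction, no bias). Below is a minimal PyTorch sketch of the same computation; the commented comparison assumes the extension built from this commit (`awq_inference_engine`) and the `[b, n, c]` fp16 CUDA tensors expected by `layernorm_forward_cuda`.

```
# Reference for generalT5LayerNorm / layernorm_forward_cuda (sketch, not part of the repo).
import torch

def t5_layernorm_ref(x: torch.Tensor, gamma: torch.Tensor, eps: float) -> torch.Tensor:
    # variance is the mean of squares over the hidden dimension (no mean subtraction)
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps) * gamma.float()).to(x.dtype)

x = torch.randn(1, 8, 4096, dtype=torch.float16, device="cuda")
gamma = torch.ones(4096, dtype=torch.float16, device="cuda")
ref = t5_layernorm_ref(x, gamma, eps=1e-6)

# Comparison against the compiled kernel (assumes the extension is installed):
# import awq_inference_engine
# out = torch.empty_like(x)
# awq_inference_engine.layernorm_forward_cuda(x, gamma, out, 1e-6)
# torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```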
3 changes: 3 additions & 0 deletions awq/kernels/csrc/layernorm/layernorm.h
@@ -0,0 +1,3 @@
#include <torch/extension.h>

void layernorm_forward_cuda(torch::Tensor _input, torch::Tensor _gamma, torch::Tensor _out, float eps);
82 changes: 82 additions & 0 deletions awq/kernels/csrc/layernorm/reduction.cuh
@@ -0,0 +1,82 @@
/*
Adapted from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/reduce_kernel_utils.cuh
*/

#pragma once
#include <assert.h>
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
#include <cooperative_groups/reduce.h>
#else
#include <cooperative_groups.h>
#endif
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <float.h>
#include <type_traits>

static const float HALF_FLT_MAX = 65504.F;
#define FINAL_MASK 0xffffffff


template<typename T>
inline __device__ T add(T a, T b) {
    return a + b;
}

template<>
inline __device__ half2 add(half2 a, half2 b) {
    return __hadd2(a, b);
}

template<>
inline __device__ half add(half a, half b) {
    return __hadd(a, b);
}

template<typename T>
__inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1)
        val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));  // __shfl_sync bf16 return float when sm < 80
    return val;
}

/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val)
{
    static __shared__ T shared[32];
    int lane = threadIdx.x & 0x1f;
    int wid = threadIdx.x >> 5;

    val = warpReduceSum<T>(val);

    if (lane == 0)
        shared[wid] = val;

    __syncthreads();

    // Use blockDim.x / 32.f (rather than blockDim.x >> 5) so the check also
    // works when blockDim.x is not divisible by 32.
    val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
    val = warpReduceSum<T>(val);

    return val;
}


template<typename T>
__device__ __forceinline__ T clamp_inf_for_half(const float input)
{
    return input;
}

template<>
__device__ __forceinline__ half clamp_inf_for_half(const float input)
{
    // clamp inf values to enable fp16 training
    return input > 0.0f ? __float2half(min(input, HALF_FLT_MAX - 1000)) : __float2half(max(input, -HALF_FLT_MAX + 1000));
}
9 changes: 9 additions & 0 deletions awq/kernels/csrc/position_embedding/pos_encoding.h
@@ -0,0 +1,9 @@
#pragma once
#include <torch/extension.h>

void rotary_embedding_neox(
    torch::Tensor& positions,
    torch::Tensor& query,
    torch::Tensor& key,
    int head_size,
    torch::Tensor& cos_sin_cache);
88 changes: 88 additions & 0 deletions awq/kernels/csrc/position_embedding/pos_encoding_kernels.cu
@@ -0,0 +1,88 @@
/*
Adapted from the VLLM project:
https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
*/

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "pos_encoding.h"

template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
    const int64_t* __restrict__ positions,       // [num_tokens]
    scalar_t* __restrict__ query,                // [num_tokens, num_heads, head_size]
    scalar_t* __restrict__ key,                  // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
    const int rot_dim,
    const int stride,
    const int num_heads,
    const int head_size) {
    // Each thread block is responsible for one token.
    const int token_idx = blockIdx.x;
    int64_t pos = positions[token_idx];
    const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

    const int embed_dim = rot_dim / 2;
    const int n = num_heads * embed_dim;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        const int head_idx = i / embed_dim;
        const int token_head = token_idx * stride + head_idx * head_size;

        const int rot_offset = i % embed_dim;
        const int x_index = rot_offset;
        const int y_index = embed_dim + rot_offset;

        const int out_x = token_idx * stride + head_idx * head_size + x_index;
        const int out_y = token_idx * stride + head_idx * head_size + y_index;

        const scalar_t cos = __ldg(cache_ptr + x_index);
        const scalar_t sin = __ldg(cache_ptr + y_index);

        const scalar_t q_x = query[token_head + x_index];
        const scalar_t q_y = query[token_head + y_index];
        query[out_x] = q_x * cos - q_y * sin;
        query[out_y] = q_y * cos + q_x * sin;

        const scalar_t k_x = key[token_head + x_index];
        const scalar_t k_y = key[token_head + y_index];
        key[out_x] = k_x * cos - k_y * sin;
        key[out_y] = k_y * cos + k_x * sin;
    }
}

void rotary_embedding_neox(
    torch::Tensor& positions,      // [b, num_tokens]
    torch::Tensor& query,          // [b, num_tokens, 1, num_heads, head_size]
    torch::Tensor& key,            // [b, num_tokens, 1, num_heads, head_size]
    int head_size,
    torch::Tensor& cos_sin_cache)  // [max_position, rot_dim]
{
    int num_tokens = query.size(0) * query.size(1);
    int rot_dim = cos_sin_cache.size(1);
    int num_heads = query.size(-2);
    int stride = num_heads * head_size;
    // TORCH_CHECK(stride == key.stride(0));

    dim3 grid(num_tokens);
    dim3 block(std::min(num_heads * rot_dim / 2, 512));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        query.scalar_type(),
        "rotary_embedding_neox",
        [&] {
            rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
                positions.data_ptr<int64_t>(),
                query.data_ptr<scalar_t>(),
                key.data_ptr<scalar_t>(),
                cos_sin_cache.data_ptr<scalar_t>(),
                rot_dim,
                stride,
                num_heads,
                head_size);
        });
}
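
For reference, the kernel rotates the first `rot_dim` channels of every head in the GPT-NeoX "rotate-half" layout, reading `cos` from the first half of the cache row and `sin` from the second half. A minimal PyTorch sketch of the same math (tensor layout simplified to `[num_tokens, num_heads, head_size]`; illustrative only, not part of the repo):

```
# Reference for rotary_embedding_neox_kernel (sketch).
import torch

def rotary_neox_ref(positions: torch.Tensor,     # [num_tokens], int64
                    x: torch.Tensor,              # [num_tokens, num_heads, head_size]
                    cos_sin_cache: torch.Tensor   # [max_position, rot_dim], cos half then sin half
                    ) -> torch.Tensor:
    rot_dim = cos_sin_cache.size(1)
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)  # each [num_tokens, rot_dim // 2]
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)          # broadcast over heads
    x1 = x[..., : rot_dim // 2]
    x2 = x[..., rot_dim // 2 : rot_dim]
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((rotated, x[..., rot_dim:]), dim=-1)  # channels beyond rot_dim untouched
```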

12 changes: 12 additions & 0 deletions awq/kernels/csrc/pybind.cpp
@@ -0,0 +1,12 @@
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "layernorm/layernorm.h"
#include "quantization/gemm_cuda.h"
#include "position_embedding/pos_encoding.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
    m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
    m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
}
File renamed without changes.
File renamed without changes.
@@ -1,3 +1,15 @@
/*
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/


#include <torch/extension.h>
#include "gemm_cuda.h"
#include "dequantize.cuh"
@@ -107,7 +119,6 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);

for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
// TODO: Shang: double check how to get 8.

// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
@@ -465,4 +476,3 @@ torch::Tensor gemm_forward_cuda(
}
return _out_feats.sum(0);
}

8 changes: 0 additions & 8 deletions awq/kernels/pybind.cpp

This file was deleted.

11 changes: 8 additions & 3 deletions awq/kernels/setup.py
@@ -7,12 +7,17 @@
}

setup(
    name="f16s4_gemm",
    name="awq_inference_engine",
    packages=find_packages(),
    ext_modules=[
        CUDAExtension(
            name="f16s4_gemm",
            sources=["pybind.cpp", "gemm_cuda_gen.cu"],
            name="awq_inference_engine",
            sources=[
                "csrc/pybind.cpp",
                "csrc/quantization/gemm_cuda_gen.cu",
                "csrc/layernorm/layernorm.cu",
                "csrc/position_embedding/pos_encoding_kernels.cu"
            ],
            extra_compile_args=extra_compile_args,
        ),
    ],
