GPTQ & AWQ Fused MOE #2761

Open

wants to merge 29 commits into base: main

Changes from 1 commit

Commits (29)
2d970c5
Add kernel
chu-tianxiang Feb 4, 2024
281354a
Add group gemm kernel for gptq
chu-tianxiang Feb 5, 2024
2a1c106
Add dequant kernel
chu-tianxiang Feb 5, 2024
e9846d0
Add awq supprt
chu-tianxiang Feb 6, 2024
a9d65a9
Add test
chu-tianxiang Feb 6, 2024
7dea006
format
chu-tianxiang Feb 6, 2024
dfbb034
Merge main and fix problem in kernel
chu-tianxiang Feb 7, 2024
46d15fb
format
chu-tianxiang Feb 7, 2024
78e6e70
Merge main and fix conflicts
chu-tianxiang Feb 16, 2024
6b3e23e
Fix unit test
chu-tianxiang Feb 23, 2024
238f544
Merge branch 'main' into moe_exp
chu-tianxiang Feb 24, 2024
d43445e
Add guard for awq unit test
chu-tianxiang Feb 24, 2024
2c68478
Fix format
chu-tianxiang Feb 24, 2024
2c27dcc
test
chu-tianxiang Feb 24, 2024
c1b98ef
merge main
chu-tianxiang Feb 27, 2024
68d34af
Fix import
chu-tianxiang Feb 27, 2024
6e69101
Merge branch 'main' into moe_exp
chu-tianxiang Feb 27, 2024
d956844
fix format
chu-tianxiang Feb 27, 2024
f19ddfb
Merge main and fix conflicts
chu-tianxiang Feb 29, 2024
7a4ba90
Adapt gptq dequant to 3/8-bit
chu-tianxiang Mar 1, 2024
2fe491d
Merge main branch
chu-tianxiang Mar 3, 2024
4ef69d5
Fix marlin
chu-tianxiang Mar 3, 2024
7a11506
Merge main branch and fix conflicts
chu-tianxiang Mar 12, 2024
9d6f7d1
Fix format check
chu-tianxiang Mar 12, 2024
d08c4fa
Merge main
chu-tianxiang Mar 29, 2024
4faebc3
Fix isort
chu-tianxiang Mar 29, 2024
e8b2127
Fix format
chu-tianxiang Mar 29, 2024
1922e83
Replace expert parallel with tensor parallel
chu-tianxiang Apr 8, 2024
8bc089f
Fix typo
chu-tianxiang Apr 8, 2024
Add dequant kernel
chu-tianxiang committed Feb 5, 2024
commit 2a1c106119b89a2afca6179751af13118f73763a
8 changes: 8 additions & 0 deletions csrc/ops.h
@@ -114,6 +114,14 @@ torch::Tensor group_gptq_gemm(
bool use_exllama
);

torch::Tensor dequant_gptq(
torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx,
bool use_exllama
);

void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,
3 changes: 2 additions & 1 deletion csrc/pybind.cpp
@@ -54,8 +54,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("group_gptq_gemm", &group_gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("group_gptq_gemm", &group_gptq_gemm, "Grouped Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("dequant_gptq", &dequant_gptq, "Dequantize gptq weight to half");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def(
"moe_align_block_size",
121 changes: 104 additions & 17 deletions csrc/quantization/gptq/q_gemm.cu
@@ -316,6 +316,14 @@ __global__ void reconstruct_exllama_kernel
half* __restrict__ b
)
{
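// gridDim.z indexes the expert: advance every pointer to this expert's slice.
// The 4-bit weights and zero-points are packed 8 values per uint32, hence the /8.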
if (blockIdx.z > 0){
b_q_weight = b_q_weight + blockIdx.z * size_k * size_n / 8;
b_gptq_scales = b_gptq_scales + blockIdx.z * groups * size_n;
b_gptq_qzeros = b_gptq_qzeros + blockIdx.z * groups * size_n / 8;
b_q_perm = b_q_perm + blockIdx.z * size_k;
b = b + blockIdx.z * size_k * size_n;
}

MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
@@ -426,14 +434,16 @@ void reconstruct_exllama
half* out,
int height,
int width,
int groups
int groups,
int num_experts
)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
gridDim.z = num_experts;

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>
@@ -597,6 +607,13 @@ __global__ void reconstruct_gptq_kernel
half* __restrict__ out
)
{
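// As in reconstruct_exllama_kernel above, blockIdx.z selects the expert and
// all pointers are shifted to that expert's packed slice of the weights.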
if (blockIdx.z > 0){
w = w + blockIdx.z * height * width / 8;
w_scales = w_scales + blockIdx.z * group * width;
w_zeros = w_zeros + blockIdx.z * group * width / 8;
g_idx = g_idx + blockIdx.z * height;
out = out + blockIdx.z * height * width;
}
// Start of block

int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
@@ -634,14 +651,16 @@ void reconstruct_gptq
half* out,
int height,
int width,
int groups
int groups,
int num_experts
)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
gridDim.y = DIVIDE(height, 8);
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
gridDim.z = num_experts;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
(
@@ -678,12 +697,12 @@ void gemm_half_q_half_cuda
// Reconstruct FP16 matrix, then cuBLAS
if (use_exllama) {
reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq,
size_k, size_n, groups);
size_k, size_n, groups, 1);
}
else
{
reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
temp_dq, size_k, size_n, groups);
temp_dq, size_k, size_n, groups, 1);
}

const half alpha = __float2half(1.0f);
@@ -836,7 +855,7 @@ __global__ void group_gemm_half_q_half_gptq_kernel
const int size_k,
const int groups,
const int* __restrict__ b_q_perm,
const half* __restrict__ topk_weights,
const float* __restrict__ topk_weights,
const int* __restrict__ sorted_token_ids_ptr,
const int* __restrict__ expert_ids_ptr,
const int* __restrict__ num_tokens_post_padded,
@@ -846,7 +865,7 @@
{
int num_tokens = *num_tokens_post_padded;
int offset_m = blockIdx.y * m_count;
if (offset_m >= num_tokens) return
if (offset_m >= num_tokens) return;

int expert_id = expert_ids_ptr[blockIdx.y];
b_q_weight = b_q_weight + size_k * size_n / 8 * expert_id;
@@ -976,14 +995,15 @@ __global__ void group_gemm_half_q_half_gptq_kernel

for (int m = 0; m < valid_count; m++)
{
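// Scale the fp32 accumulators by the token's routing weight before rounding
// to half, then accumulate the four outputs into c via atomicAdd.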
if (topk_weights) {
#pragma unroll
for (int j = 0; j < 4; ++j) {
block_c[m][j] = block_c[m][j] * topk_weights[token_a[m]];
}
}
half2 *out = (half2*) c_.item_ptr(token_a[m], n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
if (topk_weights) {
half2 topk_weight = __half2half2(topk_weights[token_a[m]]);
result01 = __hmul2(result01, topk_weight);
result23 = __hmul2(result23, topk_weight);
}
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
@@ -997,7 +1017,7 @@ void group_gemm_half_q_half_cuda
const half* b_gptq_scales,
const int* b_q_perm,
half* c,
const half* __restrict__ topk_weights,
const float* __restrict__ topk_weights,
const int* __restrict__ sorted_token_ids_ptr,
const int* __restrict__ expert_ids_ptr,
const int* __restrict__ num_tokens_post_padded,
@@ -1051,7 +1071,7 @@ __global__ void group_gemm_half_q_half_alt_kernel(
int height,
int width,
int groups,
const half* __restrict__ topk_weights,
const float* __restrict__ topk_weights,
const int* __restrict__ sorted_token_ids_ptr,
const int* __restrict__ expert_ids_ptr,
const int* __restrict__ num_tokens_post_padded,
@@ -1155,7 +1175,7 @@ __global__ void group_gemm_half_q_half_alt_kernel(
}
for (int m = 0; m < b_end; m++) {
if (topk_weights) {
res[m] = __hmul(res[m], topk_weights[token_a[m]]);
res[m] = __float2half(__half2float(res[m]) * topk_weights[token_a[m]]);
}
atomicAdd(&mul[token_a[m] * width + w], res[m]);
}
@@ -1170,7 +1190,7 @@ void group_gemm_half_q_half_alt
const half* b_gptq_scales,
const int* b_g_idx,
half* c,
const half* __restrict__ topk_weights,
const float* __restrict__ topk_weights,
const int* __restrict__ sorted_token_ids_ptr,
const int* __restrict__ expert_ids_ptr,
const int* __restrict__ num_tokens_post_padded,
@@ -1221,7 +1241,7 @@ void group_gemm_half_q_half_cuda
const half* b_gptq_scales,
const int* b_g_idx,
half* c,
const half* __restrict__ topk_weights,
const float* __restrict__ topk_weights,
const int* __restrict__ sorted_token_ids_ptr,
const int* __restrict__ expert_ids_ptr,
const int* __restrict__ num_tokens_post_padded,
@@ -1251,6 +1271,31 @@ void group_gemm_half_q_half_cuda
}
}

void dequant_gptq_cuda
(
const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales,
const int* b_g_idx,
half* temp_dq,
int size_k,
int size_n,
int groups,
int num_experts,
bool use_exllama
)
{
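// Shared entry point for single-matrix and MoE (num_experts > 1) dequantization;
// both reconstruction paths now take the expert count for their grid z-dimension.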
if (use_exllama) {
reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq,
size_k, size_n, groups, num_experts);
}
else
{
reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
temp_dq, size_k, size_n, groups, num_experts);
}
}

} // namespace gptq
} // namespace vllm

@@ -1331,7 +1376,7 @@ torch::Tensor group_gptq_gemm
(const half*) b_gptq_scales.data_ptr(),
b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(),
(half*) c.data_ptr(),
mul_weights ? (const half*) topk_weights.data_ptr() : NULL,
mul_weights ? (const float*) topk_weights.data_ptr() : NULL,
(const int*) sorted_token_ids_ptr.data_ptr(),
(const int*) expert_ids_ptr.data_ptr(),
(const int*) num_tokens_post_padded.data_ptr(),
@@ -1345,4 +1390,46 @@ torch::Tensor group_gptq_gemm
use_exllama
);
return c;
}

torch::Tensor dequant_gptq
(
torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx,
bool use_exllama
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_gptq_scales));
auto options = torch::TensorOptions().dtype(b_gptq_scales.dtype()).device(b_gptq_scales.device());

at::Tensor temp_dq;
int num_experts;
int size_k;
int size_n;
int groups;
// moe
if (b_q_weight.dim() == 3) {
temp_dq = torch::empty({b_q_weight.size(0), b_q_weight.size(1) * 8, b_q_weight.size(2)}, options);
num_experts = b_q_weight.size(0);
size_k = b_q_weight.size(1) * 8;
size_n = b_q_weight.size(2);
groups = b_gptq_scales.size(1);
} else
{
temp_dq = torch::empty({b_q_weight.size(0) * 8, b_q_weight.size(1)}, options);
num_experts = 1;
size_k = b_q_weight.size(0) * 8;
size_n = b_q_weight.size(1);
groups = b_gptq_scales.size(0);
}
vllm::gptq::dequant_gptq_cuda(
(const uint32_t*) b_q_weight.data_ptr(),
(const uint32_t*)b_gptq_qzeros.data_ptr(),
(const half*) b_gptq_scales.data_ptr(),
b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(),
(half*) temp_dq.data_ptr(),
size_k, size_n, groups,
num_experts, use_exllama);
return temp_dq;
}
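
For reference, a minimal sketch of how the new binding could be exercised from Python once the extension is built. The import path (vllm._C.ops) and the random packed data are assumptions for illustration; only the shape contract follows from the wrapper above: a 3-D packed weight of shape (num_experts, K // 8, N) dequantizes to a half tensor of shape (num_experts, K, N), while a 2-D weight keeps the existing dense behavior.

import torch
from vllm._C import ops  # assumed import path for the compiled extension

num_experts, K, N, group_size = 8, 4096, 14336, 128
groups = K // group_size

# Illustrative 4-bit GPTQ tensors, packed 8 values per int32.
qweight = torch.randint(0, 2**31 - 1, (num_experts, K // 8, N),
                        dtype=torch.int32, device="cuda")
qzeros = torch.randint(0, 2**31 - 1, (num_experts, groups, N // 8),
                       dtype=torch.int32, device="cuda")
scales = torch.rand(num_experts, groups, N, dtype=torch.float16, device="cuda")
g_idx = torch.empty((0,), dtype=torch.int32, device="meta")  # no act-order permutation

dq = ops.dequant_gptq(qweight, qzeros, scales, g_idx, True)  # use_exllama=True
assert dq.shape == (num_experts, K, N) and dq.dtype == torch.float16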
12 changes: 6 additions & 6 deletions vllm/model_executor/layers/fused_moe.py
@@ -142,7 +142,7 @@ def moe_align_block_size(
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size.

This function pads the number of tokens that each expert needs to process so that it is divisible by block_size.
This function pads the number of tokens that each expert needs to process so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions align correctly.

Example:
Expand All @@ -151,7 +151,7 @@ def moe_align_block_size(
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
- After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations.
"""
@@ -218,23 +218,23 @@ def fused_moe(hidden_states: torch.Tensor,
inplace=False):
"""
This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.

Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The weights for the top-k selected experts.
- topk_ids (torch.Tensor): The indices of the top-k selected experts.
- inplace (bool): If True, perform the operation in-place. Defaults to False.

Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
# assert w1.is_contiguous(), "Expert weights1 must be contiguous"
# assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
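
Finally, a rough end-to-end sketch of driving fused_moe itself. The tensor shapes are inferred from the assertions above and the usual w1 (gate+up) / w2 (down) layout, and the softmax-plus-top-k routing shown here is an assumption for illustration, not part of this diff:

import torch
from vllm.model_executor.layers.fused_moe import fused_moe  # assumed import

num_tokens, hidden, inter, num_experts, topk = 16, 4096, 14336, 8, 2
hidden_states = torch.randn(num_tokens, hidden, dtype=torch.float16, device="cuda")
w1 = torch.randn(num_experts, 2 * inter, hidden, dtype=torch.float16, device="cuda")  # gate + up proj
w2 = torch.randn(num_experts, hidden, inter, dtype=torch.float16, device="cuda")      # down proj

# Hypothetical router: softmax over expert logits, then top-k selection.
gating = torch.randn(num_tokens, num_experts, dtype=torch.float32, device="cuda")
topk_weights, topk_ids = torch.topk(torch.softmax(gating, dim=-1), topk, dim=-1)

out = fused_moe(hidden_states, w1, w2, topk_weights, topk_ids, inplace=False)
assert out.shape == hidden_states.shape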