Skip to content

Commit

Permalink
fix dequantize awq
Browse files Browse the repository at this point in the history
  • Loading branch information
minhthuc2502 committed Jun 20, 2024
1 parent 0dc14ed commit 44e6ff1
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 23 deletions.
9 changes: 8 additions & 1 deletion src/layers/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,14 @@ namespace ctranslate2 {
StorageView weight_dequant(input.dtype(), input.device());
ops::DequantizeAwq dequantize_awq_op;
dequantize_awq_op(*weight, *qscale, *_qzero, weight_dequant);
_gemm_op(input, weight_dequant, output, nullptr, bias);
ops::Gemm gemm_op(/*alpha=*/1,
/*beta=*/0,
/*trans_a=*/false,
/*trans_b=*/false,
/*a_is_packed=*/false,
/*b_is_packed*/false,
_activation_type);
gemm_op(input, weight_dequant, output, nullptr, bias);
} else {
ops::GemmAwq gemm_awq_op(/*alpha=*/1, /*beta=*/0, /*trans_a=*/false, /*trans_b=*/false,
/*a_is_packed=*/false, /*b_is_packed=*/false, _activation_type);
Expand Down
25 changes: 4 additions & 21 deletions src/ops/awq/dequantize_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,11 @@ namespace ctranslate2 {
zeros = zeros + blockIdx.z * in_c * out_c / G / 8;
C = C + blockIdx.z * in_c * out_c;
}
int j_factors1 = 4;
int row_stride2 = 4;
int split_k_iters = 1;
static constexpr uint32_t ZERO = 0x0;
half B_shared[32 * (128 + 8)];

half* B_shared_ptr2 = B_shared;

half B_shared_warp[32];
int OC = 512;

int N = blockDim.x * gridDim.x; // 2
int col = (blockIdx.x * blockDim.x + threadIdx.x);
int row = blockIdx.y * blockDim.y + threadIdx.y;
Expand Down Expand Up @@ -73,35 +67,24 @@ namespace ctranslate2 {
const StorageView& scale,
const StorageView& zero,
StorageView& output) const {
dim_t in_c = input.size() == 2 ? input.dim(0) : input.dim(1);
dim_t qout_c = input.size() == 2 ? input.dim(1) : input.dim(2);
int num_experts = input.size() == 2 ? 1 : input.dim(0);
dim_t in_c = input.rank() == 2 ? input.dim(0) : input.dim(1);
dim_t qout_c = input.rank() == 2 ? input.dim(1) : input.dim(2);
int num_experts = input.rank() == 2 ? 1 : input.dim(0);
int out_c = qout_c * 8;
int G = in_c / (input.size() == 2 ? scale.dim(0) : scale.dim(1));
int G = in_c / (input.rank() == 2 ? scale.dim(0) : scale.dim(1));

int x_thread = 0 /*thx*/;
int y_thread = 0 /*thy*/;

int x_blocks = 1;
int y_blocks = 1;
//if (thx==0) {
x_thread = qout_c;
//}
//if (thy==0) {
y_thread = in_c;
//}

//int dbg_ = true;
//if (thx==0 && thy==0) {
int dbg = false;
x_thread = 8;
y_thread = 8;
x_blocks = (int)(qout_c / 8);
y_blocks = (int)(in_c / 8);
//}
//dbg = dbg && dbg_;

//auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
if (num_experts == 1) {
output.resize({in_c, out_c});
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/ops/awq/gemv_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace ctranslate2 {

int j_factors1 = ((OC + 64 - 1) / 64);

int blockIdx_x = 0;
//int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 128 - 1) / 128 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 128 - 1) / 128 * j_factors1);

Expand Down

0 comments on commit 44e6ff1

Please sign in to comment.