Skip to content

Commit

Permalink
Fix overflow in awq kernel (vllm-project#1295)
Browse files Browse the repository at this point in the history
Co-authored-by: 楚天翔 <tianxiang.ctx@alibaba-inc.com>
  • Loading branch information
chu-tianxiang and chu-tianxiang authored Oct 11, 2023
1 parent 8285736 commit 980dd4a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions csrc/quantization/awq/gemm_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
+ (((int)threadIdx.x) % (128 / 8)) * 8;

half* C_ptr = C
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 128
+ ((int)threadIdx.y) * 64
+ (((int)threadIdx.x) % 4) * 2;
Expand Down Expand Up @@ -323,7 +323,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
+ (((int)threadIdx.x) % (64 / 8)) * 8;

half* C_ptr = C
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 64
+ ((int)threadIdx.y) * 32
+ (((int)threadIdx.x) % 4) * 2;
Expand Down

0 comments on commit 980dd4a

Please sign in to comment.