@@ -17,14 +17,6 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
1717namespace vllm {
1818namespace awq {
1919
20- // Pack two half values.
21- static inline __device__ __host__ unsigned __pack_half2 (const half x,
22- const half y) {
23- unsigned v0 = *((unsigned short *)&x);
24- unsigned v1 = *((unsigned short *)&y);
25- return (v1 << 16 ) | v0;
26- }
27-
2820template <int N>
2921__global__ void __launch_bounds__ (64 )
3022 gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters,
@@ -42,11 +34,7 @@ __global__ void __launch_bounds__(64)
4234 __shared__ half A_shared[16 * (32 + 8 )];
4335 __shared__ half B_shared[32 * (N + 8 )];
4436
45- __shared__ half scaling_factors_shared[N];
46- __shared__ half zeros_shared[N];
47-
4837 int j_factors1 = ((OC + N - 1 ) / N);
49- int blockIdx_x = 0 ;
5038 int blockIdx_y = blockIdx .x % ((M + 16 - 1 ) / 16 * j_factors1);
5139 int blockIdx_z = blockIdx .x / ((M + 16 - 1 ) / 16 * j_factors1);
5240
@@ -60,7 +48,6 @@ __global__ void __launch_bounds__(64)
6048
6149 static constexpr int row_stride_warp = 32 * 8 / 32 ;
6250 static constexpr int row_stride = 2 * 32 * 8 / N;
63- bool ld_zero_flag = (threadIdx .y * 32 + threadIdx .x ) * 8 < N;
6451 // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
6552 bool ld_A_flag =
6653 (blockIdx_y / j_factors1 * 16 + threadIdx .y * row_stride_warp +
@@ -145,11 +132,7 @@ __global__ void __launch_bounds__(64)
145132 uint32_t B_loaded =
146133 *(uint32_t *)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8 ));
147134 uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2 (B_loaded);
148- // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
149- // 8)) * 8);
150135
151- // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
152- // % (cta_N / 8)) * 8);
153136 // - zero and * scale
154137 // TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
155138 // q * scale - zero * scale.
@@ -367,17 +350,11 @@ __global__ void __launch_bounds__(64)
367350__global__ void __launch_bounds__ (64 )
368351 dequantize_weights(int * __restrict__ B, half* __restrict__ scaling_factors,
369352 int * __restrict__ zeros, half* __restrict__ C, int G) {
370- int j_factors1 = 4 ;
371- int row_stride2 = 4 ;
372- int split_k_iters = 1 ;
373353 static constexpr uint32_t ZERO = 0x0 ;
374354 half B_shared[32 * (128 + 8 )];
375355
376356 half* B_shared_ptr2 = B_shared;
377357
378- half B_shared_warp[32 ];
379- int OC = 512 ;
380-
381358 int N = blockDim .x * gridDim .x ; // 2
382359 int col = (blockIdx .x * blockDim .x + threadIdx .x );
383360 int row = blockIdx .y * blockDim .y + threadIdx .y ;
0 commit comments