@@ -59,14 +59,16 @@ __global__ void Marlin(
     const int4* __restrict__ A,  // fp16 input matrix of shape mxk
     const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
     int4* __restrict__ C,        // fp16 output buffer of shape mxn
+    int4* __restrict__ C_tmp,    // fp32 tmp output buffer (for reduce)
     const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                           // (k/groupsize)xn
     const int* __restrict__ g_idx,        // int32 group indices of shape k
-    int num_groups,     // number of scale groups per output channel
-    int prob_m,         // batch dimension m
-    int prob_n,         // output dimension n
-    int prob_k,         // reduction dimension k
-    int* locks          // extra global storage for barrier synchronization
+    int num_groups,       // number of scale groups per output channel
+    int prob_m,           // batch dimension m
+    int prob_n,           // output dimension n
+    int prob_k,           // reduction dimension k
+    int* locks,           // extra global storage for barrier synchronization
+    bool use_fp32_reduce  // whether to use fp32 global reduce
 ) {}
 
 }  // namespace gptq_marlin
@@ -532,16 +534,18 @@ __global__ void Marlin(
     const int4* __restrict__ A,  // fp16 input matrix of shape mxk
     const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
     int4* __restrict__ C,        // fp16 output buffer of shape mxn
+    int4* __restrict__ C_tmp,    // fp32 tmp output buffer (for reduce)
     const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                           // (k/groupsize)xn
     const int4* __restrict__ zp_ptr,      // 4bit packed zero-points of shape
                                           // (k/groupsize)x(n/pack_factor)
     const int* __restrict__ g_idx,        // int32 group indices of shape k
-    int num_groups,     // number of scale groups per output channel
-    int prob_m,         // batch dimension m
-    int prob_n,         // output dimension n
-    int prob_k,         // reduction dimension k
-    int* locks          // extra global storage for barrier synchronization
+    int num_groups,       // number of scale groups per output channel
+    int prob_m,           // batch dimension m
+    int prob_n,           // output dimension n
+    int prob_k,           // reduction dimension k
+    int* locks,           // extra global storage for barrier synchronization
+    bool use_fp32_reduce  // whether to use fp32 global reduce
 ) {
   // Each threadblock processes one "stripe" of the B matrix with (roughly) the
   // same size, which might involve multiple column "slices" (of width 16 *
@@ -595,13 +599,16 @@ __global__ void Marlin(
   int slice_idx;  // index of threadblock in current slice; numbered bottom to
                   // top
 
+  int par_id = 0;
+
   // We can easily implement parallel problem execution by just remapping
   // indices and advancing global pointers
   if (slice_col_par >= n_tiles) {
     A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
     C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
     locks += (slice_col_par / n_tiles) * n_tiles;
     slice_col = slice_col_par % n_tiles;
+    par_id = slice_col_par / n_tiles;
   }
 
   // Compute all information about the current slice which is required for
@@ -632,6 +639,7 @@ __global__ void Marlin(
       C += 16 * thread_m_blocks * prob_n / 8;
       locks += n_tiles;
       slice_col = 0;
+      par_id++;
     }
   };
   init_slice();
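
Note: par_id is the index of the "parallel" sub-problem, i.e. which chunk of 16 * thread_m_blocks rows of the batch this threadblock is currently processing; global_reduce_fp32 below uses it to pick the matching region of C_tmp. A small illustrative sketch of the remapping, using only quantities that appear in the hunk above:

// Illustrative values only: with n_tiles = 8 and slice_col_par = 19,
//   par_id    = slice_col_par / n_tiles;  // 19 / 8 = 2 -> third row chunk of the batch
//   slice_col = slice_col_par % n_tiles;  // 19 % 8 = 3 -> fourth column slice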
@@ -1321,7 +1329,7 @@ __global__ void Marlin(
   // finally have to globally reduce over the results. As the striped
   // partitioning minimizes the number of such reductions and our outputs are
   // usually rather small, we perform this reduction serially in L2 cache.
-  auto global_reduce = [&](bool first = false, bool last = false) {
+  auto global_reduce_fp16 = [&](bool first = false, bool last = false) {
     // We are very careful here to reduce directly in the output buffer to
     // maximize L2 cache utilization in this step. To do this, we write out
     // results in FP16 (but still reduce with FP32 compute).
@@ -1382,6 +1390,53 @@ __global__ void Marlin(
     }
   };
 
+  // Globally reduce over threadblocks that compute the same column block.
+  // We use a tmp C buffer to reduce in full fp32 precision.
+  auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
+    constexpr int tb_m = thread_m_blocks * 16;
+    constexpr int tb_n = thread_n_blocks * 16;
+
+    constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
+
+    constexpr int active_threads = 32 * thread_n_blocks / 4;
+    bool is_th_active = threadIdx.x < active_threads;
+
+    int par_offset = c_size * n_tiles * par_id;
+    int slice_offset = c_size * slice_col;
+
+    constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
+    constexpr int th_size = num_floats * sizeof(float) / 16;
+
+    int c_cur_offset = par_offset + slice_offset;
+
+    if (!is_th_active) {
+      return;
+    }
+
+    if (!first) {
+      float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
+  #pragma unroll
+      for (int k = 0; k < th_size; k++) {
+        sh[threadIdx.x] =
+            C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
+
+        float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
+  #pragma unroll
+        for (int f = 0; f < 4; f++) {
+          frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
+        }
+      }
+    }
+
+    if (!last) {
+      int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
+  #pragma unroll
+      for (int k = 0; k < th_size; k++) {
+        C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
+      }
+    }
+  };
+
   // Write out the reduce final result in the correct layout. We only actually
   // reshuffle matrix fragments in this step, the reduction above is performed
   // in fragment layout.
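
The two reduction paths differ in where partial results live: global_reduce_fp16 accumulates directly in the fp16 output buffer C (fp32 compute, fp16 storage, maximizing L2 reuse), while global_reduce_fp32 keeps partial sums at full fp32 precision in the separate C_tmp buffer. From the offsets above, a sketch of the layout C_tmp is assumed to have (names as in the kernel; the comment is illustrative rather than part of the change):

// c_size  = tb_m * tb_n * sizeof(float) / 16;   // int4 elements per (par_id, slice_col) block
// index   = c_size * n_tiles * par_id           // which parallel sub-problem
//         + c_size * slice_col                  // which column slice
//         + active_threads * k + threadIdx.x;   // element handled by this thread
// i.e. C_tmp behaves like [par_id][slice_col][c_size] int4 values, matching the
// reduce_max_m x size_n fp32 tensor allocated on the host side further down.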
@@ -1606,7 +1661,11 @@ __global__ void Marlin(
     if (slice_count > 1) {  // only globally reduce if there is more than one
                             // block in a slice
       barrier_acquire(&locks[slice_col], slice_idx);
-      global_reduce(slice_idx == 0, last);
+      if (use_fp32_reduce) {
+        global_reduce_fp32(slice_idx == 0, last);
+      } else {
+        global_reduce_fp16(slice_idx == 0, last);
+      }
       barrier_release(&locks[slice_col], last);
     }
     if (last)  // only the last block in a slice actually writes the result
@@ -1661,8 +1720,8 @@ __global__ void Marlin(
           THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER,     \
           HAS_ZP, GROUP_BLOCKS>                                             \
           <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                \
-              A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups,    \
-              prob_m, prob_n, prob_k, locks);                               \
+              A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,     \
+              num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce);  \
   }
 
 typedef struct {
@@ -1801,6 +1860,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
   return true;
 }
 
+int determine_reduce_max_m(int prob_m, int max_par) {
+  constexpr int tile_m_size = 16;
+
+  if (prob_m <= tile_m_size) {
+    return tile_m_size;
+
+  } else if (prob_m <= tile_m_size * 2) {
+    return tile_m_size * 2;
+
+  } else if (prob_m <= tile_m_size * 3) {
+    return tile_m_size * 3;
+
+  } else if (prob_m <= tile_m_size * 4) {
+    return tile_m_size * 4;
+
+  } else {
+    int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par);
+    return tile_m_size * 4 * cur_par;
+  }
+}
+
 exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                       int num_bits, int group_size,
                                       bool has_act_order, bool is_k_full,
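
determine_reduce_max_m rounds the batch size up to whole 16-row tiles and, beyond four tiles, to whole 64-row parallel sub-problems capped by max_par. A few worked values (illustrative; assumes marlin::max_par is 16, as defined elsewhere in this file):

// determine_reduce_max_m(12, 16)    -> 16    (one 16-row tile)
// determine_reduce_max_m(40, 16)    -> 48    (three 16-row tiles)
// determine_reduce_max_m(200, 16)   -> 256   (64 * min(div_ceil(200, 64), 16) = 64 * 4)
// determine_reduce_max_m(10000, 16) -> 1024  (capped at 64 * 16 by max_par)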
@@ -1880,13 +1960,13 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
   __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
 
 template <typename scalar_t>
-void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
-                     void* g_idx, void* perm, void* a_tmp, int prob_m,
-                     int prob_n, int prob_k, void* workspace, int num_bits,
-                     bool has_act_order, bool is_k_full, bool has_zp,
-                     int num_groups, int group_size, int dev,
+void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
+                     void* s, void* zp, void* g_idx, void* perm, void* a_tmp,
+                     int prob_m, int prob_n, int prob_k, void* workspace,
+                     int num_bits, bool has_act_order, bool is_k_full,
+                     bool has_zp, int num_groups, int group_size, int dev,
                      cudaStream_t stream, int thread_k, int thread_n, int sms,
-                     int max_par) {
+                     int max_par, bool use_fp32_reduce) {
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);
   TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
@@ -1970,6 +2050,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
   const int4* A_ptr = (const int4*)A;
   const int4* B_ptr = (const int4*)B;
   int4* C_ptr = (int4*)C;
+  int4* C_tmp_ptr = (int4*)C_tmp;
   const int4* s_ptr = (const int4*)s;
   const int4* zp_ptr = (const int4*)zp;
   const int* g_idx_ptr = (const int*)g_idx;
@@ -2049,7 +2130,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& g_idx, torch::Tensor& perm,
                                torch::Tensor& workspace, int64_t num_bits,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp) {
+                               bool is_k_full, bool has_zp,
+                               bool use_fp32_reduce) {
   // Verify num_bits
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);
@@ -2099,6 +2181,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   torch::Tensor c = torch::empty({size_m, size_n}, options);
   torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
 
+  // Alloc C tmp buffer that is going to be used for the global reduce
+  int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
+  int reduce_n = size_n;
+  auto options_fp32 =
+      torch::TensorOptions().dtype(at::kFloat).device(a.device());
+  if (!use_fp32_reduce) {
+    reduce_max_m = 0;
+    reduce_n = 0;
+  }
+  torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
+
   // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
   // auto -1)
   int thread_k = -1;
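
The temporary buffer is only materialized when use_fp32_reduce is set; otherwise a 0 x 0 tensor is allocated and the kernel never touches C_tmp. Its footprint is reduce_max_m * size_n fp32 values, for example (illustrative numbers, same max_par assumption as above):

// size_m = 200, size_n = 4096  ->  reduce_max_m = 256
// bytes  = 256 * 4096 * sizeof(float) = 4 MiB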
@@ -2171,20 +2264,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   if (a.scalar_type() == at::ScalarType::Half) {
     marlin::marlin_mm_f16i4<half>(
         a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
-        b_scales.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
-        perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
+        c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
+        b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
+        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par);
+        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     marlin::marlin_mm_f16i4<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
-        c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
-        b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
-        a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
+        c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
+        b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
+        perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par);
+        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else {
     TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
   }