Skip to content

Commit 0e38322

Browse files
alexm-redhat authored and alexeykondrat committed
[Bugfix] Fix marlin kernel crash on H100 (vllm-project#4218)
This PR addresses the Marlin kernel crash on H100 that was reported in neuralmagic#187. The crash was caused by the inline PTX assembly that implemented async_copy with streaming behavior. The solution is to use the more standard PTX for async_copy (without the fractional L2 cache policy for "evict_first"). There is no performance difference between the standard async_copy PTX and the previous version.
1 parent f59920c commit 0e38322

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed

csrc/quantization/marlin/marlin_cuda_kernel.cu

Lines changed: 8 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -67,20 +67,13 @@ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
6767
"r"(smem), "l"(glob_ptr), "n"(BYTES));
6868
}
6969

70-
// Asynchronous global->shared copy with a cache hint indicating that the values
71-
// may be evicted immediately; used for quantized weights B, which are only
72-
// accessed precisely once and should thus not pollute the L2 cache which we
73-
// need for inputs A and outputs C.
74-
__device__ inline void cp_async4_stream(void *smem_ptr, const void *glob_ptr) {
70+
// Asynchronous global->shared copy
71+
__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
7572
const int BYTES = 16;
7673
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
77-
asm volatile(
78-
"{\n"
79-
" .reg .b64 p;\n"
80-
" createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
81-
" cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
82-
"}\n" ::"r"(smem),
83-
"l"(glob_ptr), "n"(BYTES));
74+
asm volatile("{\n"
75+
" cp.async.cg.shared.global [%0], [%1], %2;\n"
76+
"}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES));
8477
}
8578

8679
// Async copy fence.
@@ -448,14 +441,14 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
448441
int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
449442
#pragma unroll
450443
for (int i = 0; i < b_sh_wr_iters; i++) {
451-
cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
444+
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
452445
B_ptr[i] += b_gl_rd_delta_o;
453446
}
454447
// Only fetch scales if this tile starts a new group
455448
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
456449
int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
457450
if (s_sh_wr_pred)
458-
cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
451+
cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
459452
s_gl_rd += s_gl_rd_delta;
460453
}
461454
}
@@ -750,7 +743,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
750743
// write-out
751744
if (group_blocks == -1 && last) {
752745
if (s_sh_wr_pred)
753-
cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
746+
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
754747
cp_async_fence();
755748
}
756749
thread_block_reduce();

0 commit comments

Comments
 (0)