From 84a8ea18c406f21de8ff3a49321784980581b6bd Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Thu, 16 May 2024 10:24:34 -0400
Subject: [PATCH] [Bugfix] Fix marlin 2:4 kernel crash on H100 (#243)

The crash was caused by the inline PTX assembly that issued the async_copy
with streaming behavior. The solution is to use the more standard PTX for
async_copy (without the fractional L2 policy for "evict_first"). There is no
performance difference between the standard async_copy PTX and the previous
one.

Ported from dense marlin: https://github.com/vllm-project/vllm/pull/4218/
---
 csrc/quantization/marlin/sparse/common/mem.h   | 16 +++++-----------
 .../marlin/sparse/marlin_24_cuda_kernel.cu     | 10 +++++-----
 2 files changed, 10 insertions(+), 16 deletions(-)

diff --git a/csrc/quantization/marlin/sparse/common/mem.h b/csrc/quantization/marlin/sparse/common/mem.h
index 7cd367e76c8fd..8134f0f8e56d8 100644
--- a/csrc/quantization/marlin/sparse/common/mem.h
+++ b/csrc/quantization/marlin/sparse/common/mem.h
@@ -45,19 +45,13 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
   );
 }
 
-// Asynchronous global->shared copy with a cache hint indicating that the values may be evicted immediately; used for
-// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need
-// for inputs A and outputs C.
-__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
+// Asynchronous global->shared copy
+__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
   const int BYTES = 16;
   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-  asm volatile(
-    "{\n"
-    "   .reg .b64 p;\n"
-    "   createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
-    "   cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
-    "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)
-  );
+  asm volatile("{\n"
+               "   cp.async.cg.shared.global [%0], [%1], %2;\n"
+               "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES));
 }
 
 // Async copy fence.
diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
index 3a13d485b4199..42b0566183a8d 100644
--- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
+++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
@@ -392,7 +392,7 @@ __global__ void Marlin_24(
       for (int i = 0; i < b_sh_wr_iters; i++) {
         #pragma unroll
         for (int j = 0; j < b_thread_vecs; j++) {
-          cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j],
+          cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j],
                            B_ptr[i] + j);
         }
         B_ptr[i] += b_gl_rd_delta_o;
@@ -401,7 +401,7 @@ __global__ void Marlin_24(
       #pragma unroll
       for (int i = 0; i < m_sh_iters; i++) {
         if (m_sh_wr_pred)
-          cp_async4_stream(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr],
+          cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr],
                            meta_ptr[i]);
         meta_ptr[i] += m_gl_rd_delta_o;
       }
@@ -409,7 +409,7 @@ __global__ void Marlin_24(
       if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
         int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
         if (s_sh_wr_pred)
-          cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
+          cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
         s_gl_rd += s_gl_rd_delta;
       }
     }
@@ -763,12 +763,12 @@ __global__ void Marlin_24(
     if constexpr (group_blocks == -1) {
       if constexpr (num_bits == 8) {
         if (s_sh_wr_pred)
-          cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
+          cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
         cp_async_fence();
       } else {
         if (last) {
           if (s_sh_wr_pred)
-            cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
+            cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
           cp_async_fence();
         }
       }
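For context (not part of the patch): below is a minimal, hypothetical sketch of how the renamed cp_async4 helper pairs with the cp_async_fence and cp_async_wait<n> helpers defined in the same csrc/quantization/marlin/sparse/common/mem.h header. The kernel name and block size are illustrative assumptions only; cp.async requires an sm_80+ GPU (Ampere/Hopper).

    // Hypothetical standalone example (not from the patch): each thread of a
    // 256-thread block prefetches one 16-byte int4 vector into shared memory
    // via the fixed cp_async4 path (plain cp.async.cg, no L2::evict_first
    // hint), then copies it back out to verify the data arrived.
    __global__ void cp_async4_demo(const int4* __restrict__ in,
                                   int4* __restrict__ out) {
      __shared__ int4 smem[256];                        // one slot per thread
      cp_async4(&smem[threadIdx.x], &in[threadIdx.x]);  // issue async global->shared copy
      cp_async_fence();                                 // commit outstanding copies as a group
      cp_async_wait<0>();                               // wait for all committed groups
      __syncthreads();                                  // make shared data visible block-wide
      out[threadIdx.x] = smem[threadIdx.x];
    }

Launched as cp_async4_demo<<<1, 256>>>(d_in, d_out), this exercises the same cp.async.cg path the Marlin_24 kernel now uses, without the evict_first cache policy that crashed on H100.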