Skip to content

Commit 10c8823

Browse files
committed
Fix ROCm global load inline assembly in Marlin sparse kernel
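Modify the three cp_async4 helpers (cp_async4_pred_zfill, cp_async4_pred, and cp_async4) so that, in the USE_ROCM branch, the call to __builtin_amdgcn_global_load_lds is prefixed with extern. The intent is to make the Marlin sparse kernel's global-to-shared memory loads build correctly on ROCm platforms while leaving the inline-assembly path used on CUDA unchanged, preserving cross-platform compatibility.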
1 parent 8b00b06 commit 10c8823

File tree

1 file changed

+3
-3
lines changed
  • torchao/csrc/cuda/sparse_marlin

1 file changed

+3
-3
lines changed

torchao/csrc/cuda/sparse_marlin/mem.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr,
5151
int src_in_bytes = (zfill ? 0 : BYTES);
5252
uint32_t smem = cvta_to_shared(smem_ptr);
5353
#ifdef USE_ROCM
54-
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
54+
extern __builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
5555
#else
5656
asm volatile(
5757
"{\n"
@@ -68,7 +68,7 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
6868
const int BYTES = 16;
6969
uint32_t smem = cvta_to_shared(smem_ptr);
7070
#ifdef USE_ROCM
71-
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
71+
extern __builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
7272
#else
7373
asm volatile(
7474
"{\n"
@@ -85,7 +85,7 @@ __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
8585
const int BYTES = 16;
8686
uint32_t smem = cvta_to_shared(smem_ptr);
8787
#ifdef USE_ROCM
88-
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
88+
extern __builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
8989
#else
9090
asm volatile(
9191
"{\n"

0 commit comments

Comments
 (0)