
[WIP] ROCm sparse fix #1868


Draft: petrex wants to merge 24 commits into main from rocm_sparse_fix

Commits (24)
3a77641
Fix ROCm GPU architecture detection in setup.py
petrex Mar 11, 2025
76d68bf
Refactor CUDA and ROCm source file handling in setup.py
petrex Mar 11, 2025
16d22c1
Improve CUTLASS kernel support detection for non-Windows platforms
petrex Mar 11, 2025
7481959
Reorder source file collection in setup.py
petrex Mar 11, 2025
94d1fb4
Remove redundant NVCC compilation flag in setup.py
petrex Mar 11, 2025
72c2642
Add ROCm-specific inline assembly for sparse Marlin MMA operations
petrex Mar 11, 2025
75f4787
Fix ROCm half-precision conversion in sparse Marlin MMA
petrex Mar 11, 2025
cf79039
Optimize half-precision operations in sparse Marlin MMA
petrex Mar 11, 2025
a98a427
Optimize ROCm half-precision operations in sparse Marlin MMA
petrex Mar 11, 2025
30bd924
Fix ROCm float multiplication in sparse Marlin MMA
petrex Mar 11, 2025
66691c3
Add ROCm header support for sparse Marlin MMA implementation
petrex Mar 11, 2025
04014e7
Update ROCm float multiplication in sparse Marlin MMA
petrex Mar 11, 2025
dc53980
Optimize ROCm global to LDS transfer in sparse Marlin MMA
petrex Mar 11, 2025
6f43e01
Simplify ROCm float multiplication in sparse Marlin MMA
petrex Mar 11, 2025
ed9282d
Fix CUDA kernel attribute setting in Marlin sparse MMA implementation
petrex Mar 11, 2025
b539062
Enhance ROCm global to LDS transfer with size-specific load instructions
petrex Mar 11, 2025
c316a98
Fix missing closing braces in ROCm cp_async4 memory transfer functions
petrex Mar 11, 2025
3a2481f
Remove unnecessary fallback memcpy in ROCm cp_async4 memory transfer …
petrex Mar 11, 2025
59455ed
Remove 16-byte ds_load instruction in ROCm cp_async4 memory transfer …
petrex Mar 11, 2025
ca6c646
global_load_dwordx4
petrex Mar 11, 2025
d1a9df9
Improve ROCm memory load instructions in sparse Marlin MMA implementa…
petrex Mar 11, 2025
63e8d5e
Refine ROCm memory load instruction in sparse Marlin ldsm4_m function
petrex Mar 11, 2025
3e5a411
Update ROCm MFMA instruction syntax in sparse Marlin MMA implementation
petrex Mar 11, 2025
479cc1d
Merge branch 'main' into rocm_sparse_fix
petrex Apr 18, 2025
21 changes: 11 additions & 10 deletions setup.py
@@ -97,9 +97,9 @@ def __init__(self):
default=(self._is_arm64() and self._is_macos()),
)
if self.build_cpu_aarch64:
assert (
self._is_arm64()
), "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine"
assert self._is_arm64(), (
"TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine"
)

# TORCHAO_BUILD_KLEIDIAI is disabled by default for now because
# 1) It increases the build time
@@ -108,9 +108,9 @@ def __init__(self):
"TORCHAO_BUILD_KLEIDIAI", default=False
)
if self.build_kleidi_ai:
assert (
self.build_cpu_aarch64
), "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set"
assert self.build_cpu_aarch64, (
"TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set"
)

# TORCHAO_BUILD_EXPERIMENTAL_MPS is disabled by default.
self.build_experimental_mps = self._os_bool_var(
@@ -119,9 +119,9 @@ def __init__(self):
if self.build_experimental_mps:
assert self._is_macos(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MacOS"
assert self._is_arm64(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires arm64"
assert (
torch.mps.is_available()
), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available"
assert torch.mps.is_available(), (
"TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available"
)

def _is_arm64(self) -> bool:
return platform.machine().startswith("arm64")
@@ -338,6 +338,7 @@ def get_extensions():
hip_sources = list(
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
)

extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
hip_sources += list(
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
@@ -350,7 +351,7 @@ def get_extensions():
# TODO: Remove this and use what CUDA has once we fix all the builds.
if IS_ROCM and use_cuda:
# Add ROCm GPU architecture check
gpu_arch = torch.cuda.get_device_properties(0).name
gpu_arch = torch.cuda.get_device_properties(0).name.gcnArchName
if gpu_arch != "gfx942":
print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
print(
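For reference, a minimal sketch of the same gfx942 check done on the C++/HIP side rather than through torch (not part of this diff; it assumes only the standard HIP runtime API, where hipDeviceProp_t exposes the architecture string as gcnArchName):

```cpp
#include <hip/hip_runtime.h>
#include <cstring>

// Sketch: returns true when the given device reports a gfx942 (MI300-class) GPU.
// gcnArchName may carry feature suffixes such as ":sramecc+:xnack-",
// so only the prefix is compared.
inline bool is_gfx942(int device = 0) {
  hipDeviceProp_t prop;
  if (hipGetDeviceProperties(&prop, device) != hipSuccess) {
    return false;  // treat a failed query as unsupported
  }
  return std::strncmp(prop.gcnArchName, "gfx942", 6) == 0;
}
```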
4 changes: 2 additions & 2 deletions torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu
@@ -867,8 +867,8 @@ __global__ void Marlin_24(
thread_k_blocks == THREAD_K_BLOCKS && \
group_blocks == GROUP_BLOCKS) { \
cudaFuncSetAttribute( \
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
reinterpret_cast<const void*>(&Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>), \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS> \
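The reinterpret_cast added in this hunk exists because the attribute call needs a plain function address. A minimal sketch of the pattern on its own (MyKernel and its parameters are hypothetical; the assumption is that the hipified hipFuncSetAttribute, unlike CUDA's templated overload, accepts only const void*):

```cpp
#include <cuda_runtime.h>

// Hypothetical templated kernel standing in for Marlin_24<...>.
template <int THREADS>
__global__ void MyKernel(const float* in, float* out) {
  extern __shared__ float smem[];
  // kernel body elided
}

// Raise the dynamic shared-memory limit for one template instantiation.
// Taking the instantiation's address and casting to const void* matches
// both the CUDA C API and the hipified hipFuncSetAttribute signature.
inline void configure(int max_shared_mem) {
  cudaFuncSetAttribute(
      reinterpret_cast<const void*>(&MyKernel<256>),
      cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);
}
```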
78 changes: 64 additions & 14 deletions torchao/csrc/cuda/sparse_marlin/mem.h
@@ -56,7 +56,18 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr,
int src_in_bytes = (zfill ? 0 : BYTES);
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
// Use appropriate ds_load instruction based on byte size
if (BYTES == 4) {
asm volatile(
"{\n"
" ds_load_b32 %0, %1\n"
"}\n" :: "v"(smem), "v"(glob_ptr));
} else if (BYTES == 8) {
asm volatile(
"{\n"
" ds_load_b64 %0, %1\n"
"}\n" :: "v"(smem), "v"(glob_ptr));
}
#else
asm volatile(
"{\n"
@@ -73,7 +84,18 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
const int BYTES = 16;
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
// Use appropriate ds_load instruction based on byte size
if (BYTES == 4) {
asm volatile(
"{\n"
" ds_load_b32 %0, %1\n"
"}\n" :: "v"(smem), "v"(glob_ptr));
} else if (BYTES == 8) {
asm volatile(
"{\n"
" ds_load_b64 %0, %1\n"
"}\n" :: "v"(smem), "v"(glob_ptr));
}
#else
asm volatile(
"{\n"
@@ -90,7 +112,18 @@ __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
__builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
// Use appropriate ds_load instruction based on byte size
if (BYTES == 4) {
asm volatile(
"{\n"
" ds_load_b32 %0, %1\n"
"}\n" :: "v"(smem), "v"(glob_ptr));
} else if (BYTES == 8) {
asm volatile(
"{\n"
" ds_load_b64 %0, %1\n"
"}\n" :: "v"(smem), "v"(glob_ptr));
}
#else
asm volatile(
"{\n"
@@ -128,11 +161,19 @@ __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
asm volatile(
"ds_read_b128 %0, %1 offset:0\n"
"ds_read_b128 %2, %1 offset:16\n"
: "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3])
: "v"(smem));
// Try using multiple ds_read_b32 instructions which are more widely supported
asm volatile(
"ds_read_b32 %0, %8 offset:0\n"
"ds_read_b32 %1, %8 offset:4\n"
"ds_read_b32 %2, %8 offset:8\n"
"ds_read_b32 %3, %8 offset:12\n"
"ds_read_b32 %4, %8 offset:16\n"
"ds_read_b32 %5, %8 offset:20\n"
"ds_read_b32 %6, %8 offset:24\n"
"ds_read_b32 %7, %8 offset:28\n"
: "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]),
"=v"(a[4]), "=v"(a[5]), "=v"(a[6]), "=v"(a[7])
: "v"(smem));
#else
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
@@ -145,7 +186,8 @@ __device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
asm volatile(
"ds_read_b64 %0, %2 offset:0\n"
"ds_read_b32 %0, %2 offset:0\n"
"ds_read_b32 %1, %2 offset:4\n"
: "=v"(a[0]), "=v"(a[1])
: "v"(smem));
#else
@@ -161,11 +203,19 @@ __device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = cvta_to_shared(smem_ptr);
#ifdef USE_ROCM
asm volatile(
"ds_read_b128 %0, %1 offset:0\n"
"ds_read_b128 %2, %1 offset:16\n"
: "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3])
: "v"(smem));
// Try using multiple ds_read_b32 instructions which are more widely supported
asm volatile(
"ds_read_b32 %0, %8 offset:0\n"
"ds_read_b32 %1, %8 offset:4\n"
"ds_read_b32 %2, %8 offset:8\n"
"ds_read_b32 %3, %8 offset:12\n"
"ds_read_b32 %4, %8 offset:16\n"
"ds_read_b32 %5, %8 offset:20\n"
"ds_read_b32 %6, %8 offset:24\n"
"ds_read_b32 %7, %8 offset:28\n"
: "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]),
"=v"(a[4]), "=v"(a[5]), "=v"(a[6]), "=v"(a[7])
: "v"(smem));
#else
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
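For comparison, a minimal sketch of a portable ROCm fallback for the ldsm4-style loads above (an assumption, not what this diff ships): since there is no ldmatrix equivalent, the 32-byte fragment can also be filled with ordinary vector loads from shared memory and the ds_read encoding left to the compiler. Note this mirrors only the load width, not ldmatrix's per-lane data shuffle.

```cpp
#include <hip/hip_runtime.h>
#include <cstdint>

// Hypothetical stand-in for the real FragA (a 32-byte fragment).
struct FragA { uint32_t w[8]; };

__device__ inline void ldsm4_fallback(FragA& frag_a, const void* smem_ptr) {
  // Two 16-byte uint4 loads cover the 32-byte fragment; the compiler is
  // free to lower them to ds_read_b128 or paired ds_read_b64 forms.
  const uint4* src = reinterpret_cast<const uint4*>(smem_ptr);
  uint4* dst = reinterpret_cast<uint4*>(&frag_a);
  dst[0] = src[0];
  dst[1] = src[1];
}
```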
83 changes: 62 additions & 21 deletions torchao/csrc/cuda/sparse_marlin/mma.h
@@ -27,18 +27,26 @@
#include <cudaTypedefs.h>
#endif

#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <device_functions.h> // For some ROCm versions
// Some intrinsics might require the compiler to be in the right mode
// with the correct target architecture flags (-march=gfx942)
#endif

namespace torchao {

// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
// is not supported. On later versions of CUDA the version without ordered
// metadata results in the following warning:
// | Advisory: Modifier ‘.sp::ordered_metadata should be used on instruction
// | mma instead of modifier ‘.sp’ as it is expected to have substantially
// | Advisory: Modifier 'sp::ordered_metadata' should be used on instruction
// | 'mma' instead of modifier 'sp' as it is expected to have substantially
// | reduced performance on some future architectures

#if defined(USE_ROCM)
// HIP ISA doesn't have an equivalent for ordered_metadata, so we'll use the standard mma instruction
#define MMA_SP_INST "v_mfma_f32_16x16x16f16 "
// Correct MFMA instruction for AMD GPUs
#define MMA_SP_INST "v_mfma_f32_16x16x16_f16 "
#elif defined(CUDA_VERSION) && CUDA_VERSION >= 12050
#define MMA_SP_INST \
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
@@ -58,6 +66,23 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,

float* c = reinterpret_cast<float*>(&frag_c);
if (psel == 0) {
#ifdef USE_ROCM
// AMD GPUs use a different syntax for MFMA instructions
// The operands need to be listed individually, not in curly braces
asm volatile(MMA_SP_INST
"%0, %4, %8, %12\n"
: "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3])
: "v"(a0[0]), "v"(a1[0]), "v"(a0[1]), "v"(a1[1]),
"v"(b[0]), "v"(b[2]), "v"(b[4]), "v"(b[6]),
"v"(c[0]), "v"(c[1]), "v"(c[2]), "v"(c[3]));

asm volatile(MMA_SP_INST
"%0, %4, %8, %12\n"
: "=v"(c[4]), "=v"(c[5]), "=v"(c[6]), "=v"(c[7])
: "v"(a0[0]), "v"(a1[0]), "v"(a0[1]), "v"(a1[1]),
"v"(b[1]), "v"(b[3]), "v"(b[5]), "v"(b[7]),
"v"(c[4]), "v"(c[5]), "v"(c[6]), "v"(c[7]));
#else
asm volatile(MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x0;\n"
@@ -72,7 +97,22 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
"f"(c[6]), "f"(c[7]), "r"(e[0]));
#endif
} else {
#ifdef USE_ROCM
asm volatile(MMA_SP_INST
"%0, %4, %8, %12\n"
: "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3])
: "v"(a0[0]), "v"(a1[0]), "v"(a0[1]), "v"(a1[1]),
"v"(b[0]), "v"(b[2]), "v"(b[4]), "v"(b[6]),
"v"(c[0]), "v"(c[1]), "v"(c[2]), "v"(c[3]));
asm volatile(MMA_SP_INST
"%0, %4, %8, %12\n"
: "=v"(c[4]), "=v"(c[5]), "=v"(c[6]), "=v"(c[7])
: "v"(a0[0]), "v"(a1[0]), "v"(a0[1]), "v"(a1[1]),
"v"(b[1]), "v"(b[3]), "v"(b[5]), "v"(b[7]),
"v"(c[4]), "v"(c[5]), "v"(c[6]), "v"(c[7]));
#else
asm volatile(MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x1;\n"
@@ -87,6 +127,7 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
"f"(c[6]), "f"(c[7]), "r"(e[0]));
#endif
}
}
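The inline asm above hand-issues the MFMA. For context, a minimal sketch of the same 16x16x16 f16 matrix-core operation through the clang builtin instead (an illustration, not what this diff emits): the builtin lets the compiler manage register constraints, but it models only the dense MFMA, since the metadata-driven sparsity of mma.sp has no direct MFMA counterpart. Assumes a matrix-core target such as gfx90a/gfx942.

```cpp
#include <hip/hip_runtime.h>

// Clang extended vector types matching the builtin's operands.
typedef _Float16 half4_t __attribute__((ext_vector_type(4)));
typedef float float4_t __attribute__((ext_vector_type(4)));

// One dense 16x16x16 f16 MFMA: returns a * b + acc under the MFMA tiling.
// cbsz / abid / blgp = 0 disables broadcast and lane-group swizzling.
__device__ inline float4_t mfma_16x16x16_f16(half4_t a, half4_t b,
                                             float4_t acc) {
  return __builtin_amdgcn_mfma_f32_16x16x16f16(a, b, acc, 0, 0, 0);
}
```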

@@ -114,8 +155,8 @@ __device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
uint2 r;
#ifdef USE_ROCM
// AMD implementation
r.x = __builtin_amdgcn_cvt_pkrtz(c0, c1);
r.y = __builtin_amdgcn_cvt_pkrtz(c2, c3);
r.x = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(c0, c1));
r.y = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(c2, c3));
#else
// NVIDIA implementation
asm("{\n\t"
@@ -177,8 +218,8 @@ __device__ inline FragB dequant_4bit(int q) {
const __half2* MUL_ptr = reinterpret_cast<const __half2*>(&MUL);
const __half2* ADD_ptr = reinterpret_cast<const __half2*>(&ADD);

frag_b[0] = __hsub(*lo_ptr, *SUB_ptr);
frag_b[1] = __hfma(*hi_ptr, *MUL_ptr, *ADD_ptr);
frag_b[0] = __hsub2(*lo_ptr, *SUB_ptr);
frag_b[1] = __hfma2(*hi_ptr, *MUL_ptr, *ADD_ptr);
#else
// NVIDIA implementation
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
@@ -211,8 +252,8 @@ __device__ inline FragB dequant_8bit(int q) {
__half2* hi_ptr = reinterpret_cast<__half2*>(&hi);
const __half2* magic_num_ptr = reinterpret_cast<const __half2*>(&I8s_TO_F16s_MAGIC_NUM);

frag_b[0] = __hsub(*lo_ptr, *magic_num_ptr);
frag_b[1] = __hsub(*hi_ptr, *magic_num_ptr);
frag_b[0] = __hsub2(*lo_ptr, *magic_num_ptr);
frag_b[1] = __hsub2(*hi_ptr, *magic_num_ptr);
#else
// NVIDIA implementation
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
@@ -229,8 +270,8 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
#ifdef USE_ROCM
// AMD implementation
__half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
frag_b[0] = __hmul(frag_b[0], s);
frag_b[1] = __hmul(frag_b[1], s);
frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s);
#else
// NVIDIA implementation
half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
@@ -243,16 +284,16 @@ __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
FragS& s0, float* c4, float* c5, float* c6,
float* c7, FragS& s1) {
#ifdef USE_ROCM
// AMD implementation
*c0 = __builtin_amdgcn_fmul_legacy(*c0, __half2float(s0[0].x));
*c1 = __builtin_amdgcn_fmul_legacy(*c1, __half2float(s0[0].y));
*c2 = __builtin_amdgcn_fmul_legacy(*c2, __half2float(s0[1].x));
*c3 = __builtin_amdgcn_fmul_legacy(*c3, __half2float(s0[1].y));
// AMD MI300X implementation
*c0 = *c0 * __half2float(s0[0].x);
*c1 = *c1 * __half2float(s0[0].y);
*c2 = *c2 * __half2float(s0[1].x);
*c3 = *c3 * __half2float(s0[1].y);

*c4 = __builtin_amdgcn_fmul_legacy(*c4, __half2float(s1[0].x));
*c5 = __builtin_amdgcn_fmul_legacy(*c5, __half2float(s1[0].y));
*c6 = __builtin_amdgcn_fmul_legacy(*c6, __half2float(s1[1].x));
*c7 = __builtin_amdgcn_fmul_legacy(*c7, __half2float(s1[1].y));
*c4 = *c4 * __half2float(s1[0].x);
*c5 = *c5 * __half2float(s1[0].y);
*c6 = *c6 * __half2float(s1[1].x);
*c7 = *c7 * __half2float(s1[1].y);
#else
// NVIDIA implementation
*c0 = __fmul_rn(*c0, __half2float(s0[0].x));
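To summarize the ROCm-side packed-half patterns this file now leans on, a minimal self-contained sketch (illustrative only; the helper names are hypothetical): cvt_pkrtz packs two floats into a half2 bit pattern, and the __h*2 intrinsics operate on both half lanes at once.

```cpp
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>

// Pack two floats into one uint32_t of two round-toward-zero halves,
// matching the to_half4 path above.
__device__ inline uint32_t pack_half2_rtz(float lo, float hi) {
  // __builtin_amdgcn_cvt_pkrtz yields a 2-lane _Float16 vector; bit-cast
  // it to the uint32_t layout the Marlin fragments expect.
  return __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(lo, hi));
}

// Scale both half lanes at once, matching the __hmul2 use in scale().
__device__ inline __half2 scale_both_lanes(__half2 v, __half s) {
  return __hmul2(v, __half2half2(s));
}
```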