12 changes: 12 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -224,6 +224,10 @@ static const char * cu_get_error_str(CUresult err) {
#define AMD_MFMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)

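// RDNA4 (gfx12) exposes WMMA matrix-multiply intrinsics; opt out at build time with GGML_HIP_NO_WMMA.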
#if defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_WMMA)
#define AMD_WMMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_WMMA)

// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#define VOLTA_MMA_AVAILABLE
@@ -283,6 +287,14 @@ static bool amd_mfma_available(const int cc) {
#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
}

static bool amd_wmma_available(const int cc) {
#if !defined(GGML_HIP_NO_WMMA)
return GGML_CUDA_CC_IS_RDNA4(cc);
#else
return false;
#endif //!defined(GGML_HIP_NO_WMMA)
}

static bool volta_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
}
130 changes: 130 additions & 0 deletions ggml/src/ggml-cuda/mma.cuh
@@ -74,6 +74,31 @@ namespace ggml_cuda_mma {
static constexpr int J = J_;

#if defined(GGML_USE_HIP)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
T x[ne] = {0};

static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true;
return false;
}

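// Lane-to-element mapping for the 16x16 tile on the 32-wide RDNA4 wave (ne = 8 per lane):
// lane n (n < 16) holds rows 0..7 of column n, lane 16+n holds rows 8..15 of column n.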
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 16) {
return 8 * (threadIdx.x / 16) + l;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
#else
static constexpr int ne = I * J / 64;
T x[ne] = {0};

@@ -119,6 +144,7 @@ namespace ggml_cuda_mma {
return -1;
}
}
#endif // defined(RDNA4)
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
static constexpr int ne = I * J / 32;
T x[ne] = {0};
@@ -236,6 +262,32 @@ namespace ggml_cuda_mma {
return -1;
}
}
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
half2 x[ne] = {{0.0f, 0.0f}};

static constexpr __device__ bool supported() {
if (I == 16 && J == 8) return true;
return false;
}

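// Lane-to-element mapping for the 16x8 half2 operand tile (ne = 4 per lane):
// lane n holds row n % 16; lanes 0..15 carry half2 columns 0..3, lanes 16..31 columns 4..7.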
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return 4 * (threadIdx.x / 16) + l;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
#endif // defined(RDNA4)
#else
static constexpr int ne = I * J / WARP_SIZE;
half2 x[ne] = {{0.0f, 0.0f}};
@@ -285,6 +337,34 @@ namespace ggml_cuda_mma {
struct tile<I_, J_, nv_bfloat162> {
static constexpr int I = I_;
static constexpr int J = J_;

#if defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

static constexpr __device__ bool supported() {
if (I == 16 && J == 8) return true;
return false;
}

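// Same lane-to-element mapping as the half2 specialization above.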
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return 4 * (threadIdx.x / 16) + l;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
#endif // defined(RDNA4)
#else
static constexpr int ne = I * J / WARP_SIZE;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

@@ -320,6 +400,7 @@ namespace ggml_cuda_mma {
return -1;
}
}
#endif // defined(AMD_WMMA_AVAILABLE)
};

template <int I, int J>
@@ -353,6 +434,19 @@ namespace ggml_cuda_mma {
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
}
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
// The <16, 8> half2/nv_bfloat162 tile is loaded as a <16, 16> tile of 16-bit elements.
if constexpr (I == 16 && J == 8 && (std::is_same<T, half2>::value || std::is_same<T, nv_bfloat162>::value)) {
constexpr int RDNA4_WMMA_MEM_N = 4;
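// 4 dwords = 16 bytes, i.e. the ne = 4 half2/nv_bfloat162 values owned by each lane.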
using TxN_t = __attribute__((ext_vector_type(RDNA4_WMMA_MEM_N))) int32_t;
reinterpret_cast<TxN_t&>(t.x[0]) = reinterpret_cast<const TxN_t&>(xs0[t.get_i(0) * stride + t.get_j(0)]);
} else {
constexpr int RDNA4_WMMA_MEM_N = 8;
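// Generic path: 8 elements of T per lane, matching ne for the <16, 16> tile layout.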
using TxN_t = __attribute__((ext_vector_type(RDNA4_WMMA_MEM_N))) T;
reinterpret_cast<TxN_t&>(t.x[0]) = reinterpret_cast<const TxN_t&>(xs0[t.get_i(0) * stride + t.get_j(0)]);
}
#endif // defined(RDNA4)
#else
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
@@ -639,12 +733,48 @@ namespace ggml_cuda_mma {
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
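// Wave32 WMMA computing D = A*B + D for a 16x16x16 f16 GEMM:
// each lane supplies 8 halves of A and B and accumulates 8 floats in place.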
using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
#endif // defined(RDNA4)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
}

static __device__ __forceinline__ void mma(
tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
#ifdef AMPERE_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
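// bf16 variant of the 16x16x16 wave32 WMMA; each lane supplies 8 __bf16 operand values.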
using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
#endif // defined(RDNA4)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
}

static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
#if defined(AMD_MFMA_AVAILABLE)
6 changes: 3 additions & 3 deletions ggml/src/ggml-cuda/mmf.cu
@@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
return false;
}
} else {
-        if (src1_ncols > 16) {
+        if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
return false;
}
}
@@ -160,9 +160,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
case GGML_TYPE_F32:
return ampere_mma_available(cc);
case GGML_TYPE_F16:
-            return volta_mma_available(cc) || turing_mma_available(cc);
+            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
case GGML_TYPE_BF16:
-            return ampere_mma_available(cc);
+            return ampere_mma_available(cc) || amd_wmma_available(cc);
default:
return false;
}