12 changes: 12 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -224,6 +224,10 @@ static const char * cu_get_error_str(CUresult err) {
#define AMD_MFMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)

#if defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_WMMA)
#define AMD_WMMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_WMMA)

#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define TURING_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -278,6 +282,14 @@ static bool amd_mfma_available(const int cc) {
#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
}

static bool amd_wmma_available(const int cc) {
#if !defined(GGML_HIP_NO_WMMA)
return GGML_CUDA_CC_IS_RDNA4(cc);
#else
return false;
#endif //!defined(GGML_HIP_NO_WMMA)
}

// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static bool turing_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
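For context on how these two common.cuh additions work together: AMD_WMMA_AVAILABLE is a device-compile-time gate (it depends on the RDNA4 architecture macro, so it is only defined while compiling device code for gfx12 targets), whereas amd_wmma_available(cc) is the host-side runtime check used when picking kernels for a given device. A minimal sketch of that pairing follows; only the macro and the helper come from this patch, the kernel and launcher names are hypothetical.

#include "common.cuh" // assumed: provides AMD_WMMA_AVAILABLE (device pass) and amd_wmma_available() (host)

// Device code: the macro selects the instruction path while each architecture is compiled.
static __global__ void example_kernel(float * dst) {
#if defined(AMD_WMMA_AVAILABLE)
    dst[threadIdx.x] = 2.0f; // an RDNA4-only path would use the WMMA builtins here
#else
    dst[threadIdx.x] = 1.0f; // generic path for every other target
#endif // defined(AMD_WMMA_AVAILABLE)
}

// Host code: the runtime helper decides per device whether the WMMA-backed variant should be used,
// mirroring how ggml_cuda_should_use_mmf consults it in mmf.cu below.
static void example_dispatch(const int cc, float * dst, cudaStream_t stream) {
    if (amd_wmma_available(cc)) {
        example_kernel<<<1, 32, 0, stream>>>(dst); // RDNA4: wave32 WMMA path compiled in
    }
    // dispatch for non-RDNA4 devices is not shown
}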
141 changes: 141 additions & 0 deletions ggml/src/ggml-cuda/mma.cuh
@@ -70,6 +70,30 @@ namespace ggml_cuda_mma {
static constexpr int J = J_;

#if defined(GGML_USE_HIP)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
T x[ne] = {0};

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 16) {
return 8 * (threadIdx.x / 16) + l;
} else if constexpr (I == 16 && J == 8) { // dummy shape for the TF32 path; only here to keep the compiler happy, not used in practice
return 4 * (threadIdx.x / 16) + l;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else if constexpr (I == 16 && J == 8) { // dummy shape for the TF32 path; only here to keep the compiler happy, not used in practice
return threadIdx.x % 16;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
#else
static constexpr int ne = I * J / 64;
T x[ne] = {0};

@@ -104,6 +128,7 @@ namespace ggml_cuda_mma {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
#endif // defined(RDNA4)
#else
static constexpr int ne = I * J / 32;
T x[ne] = {0};
@@ -140,6 +165,29 @@ namespace ggml_cuda_mma {
struct tile<I_, J_, half2> {
static constexpr int I = I_;
static constexpr int J = J_;

#if defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
half2 x[ne] = {{0.0f, 0.0f}};

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return 4 * (threadIdx.x / 16) + l;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
#endif // defined(RDNA4)
#else
static constexpr int ne = I * J / WARP_SIZE;
half2 x[ne] = {{0.0f, 0.0f}};

@@ -166,12 +214,36 @@ namespace ggml_cuda_mma {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
#endif // defined(AMD_WMMA_AVAILABLE)
};

template <int I_, int J_>
struct tile<I_, J_, nv_bfloat162> {
static constexpr int I = I_;
static constexpr int J = J_;

#if defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return 4 * (threadIdx.x / 16) + l;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
#endif // defined(RDNA4)
#else
static constexpr int ne = I * J / WARP_SIZE;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

@@ -198,6 +270,7 @@ namespace ggml_cuda_mma {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
#endif // defined(AMD_WMMA_AVAILABLE)
};

template <int I, int J>
@@ -231,6 +304,19 @@ namespace ggml_cuda_mma {
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
}
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
// Special case: a <16, 8> tile of half2/nv_bfloat162 covers a 16x16 block of scalars, so each lane loads 4 ints (16 bytes) in one vectorized access
if constexpr (I == 16 && J == 8 && (std::is_same<T, half2>::value || std::is_same<T, nv_bfloat162>::value)) {
constexpr int RDNA4_WMMA_MEM_N = 4;
using TxN_t = __attribute__((ext_vector_type(RDNA4_WMMA_MEM_N))) int;
reinterpret_cast<TxN_t&>(t.x[0]) = reinterpret_cast<const TxN_t&>(xs0[t.get_i(0) * stride + t.get_j(0)]);
} else {
constexpr int RDNA4_WMMA_MEM_N = 8;
using TxN_t = __attribute__((ext_vector_type(RDNA4_WMMA_MEM_N))) T;
reinterpret_cast<TxN_t&>(t.x[0]) = reinterpret_cast<const TxN_t&>(xs0[t.get_i(0) * stride + t.get_j(0)]);
}
#endif // defined(RDNA4)
#else
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
@@ -461,6 +547,25 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
}

static __device__ __forceinline__ void mma(
tile<16, 16, float> & D, const tile<16, 8, float> & A, const tile<16, 8, float> & B) {
#ifdef AMPERE_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, "
"%2, %3};"
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
}

static __device__ __forceinline__ void mma(
tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef TURING_MMA_AVAILABLE
@@ -489,12 +594,48 @@ namespace ggml_cuda_mma {
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
#endif // defined(RDNA4)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // TURING_MMA_AVAILABLE
}

static __device__ __forceinline__ void mma(
tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
#ifdef AMPERE_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
#endif // defined(RDNA4)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
}

static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
#if defined(AMD_MFMA_AVAILABLE)
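Reviewer note on the new RDNA4 tile layouts above: the per-lane mapping that the tile<16, 8, half2> and tile<16, 16, float> specializations define (one index is lane % 16, the other is offset by 8 * (lane / 16)) can be exercised in isolation. The sketch below is a standalone single-wave 16x16x16 FP16 WMMA multiply written against that mapping; only __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12 and the ext_vector_type fragments come from this patch, everything else is illustrative, and it assumes a wave32 gfx12 target launched with 32 threads.

#include <hip/hip_runtime.h>

using halfx8_t  = __attribute__((ext_vector_type(8))) _Float16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;

// One wave computes C (16x16, row-major) = A (16x16, row-major) * B^T, with B stored as N x K (16x16, row-major).
__global__ void wmma_f16_16x16x16_sketch(const _Float16 * A, const _Float16 * B, float * C) {
    const int lane = threadIdx.x % 32;

    halfx8_t  a_frag;
    halfx8_t  b_frag;
    floatx8_t acc;

    // A/B fragments: each lane owns row (lane % 16) and the 8 K values starting at 8 * (lane / 16),
    // i.e. the mapping tile<16, 8, half2>::get_i/get_j describe above.
    for (int k = 0; k < 8; ++k) {
        a_frag[k] = A[(lane % 16) * 16 + 8 * (lane / 16) + k];
        b_frag[k] = B[(lane % 16) * 16 + 8 * (lane / 16) + k];
    }
    for (int l = 0; l < 8; ++l) {
        acc[l] = 0.0f;
    }

    acc = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc);

    // Accumulator: each lane owns column (lane % 16) and the 8 rows starting at 8 * (lane / 16),
    // i.e. the mapping tile<16, 16, float>::get_i/get_j describe above, assuming get_i is the row
    // coming from A and get_j the column coming from B as the new mma() overloads imply.
    for (int l = 0; l < 8; ++l) {
        C[(8 * (lane / 16) + l) * 16 + (lane % 16)] = acc[l];
    }
}

If the result is transposed or wrong on a particular gfx12 part, the get_i/get_j definitions above are the authoritative place to adjust; the builtin call itself matches the one used in the new mma() overloads.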
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/mmf.cu
@@ -148,9 +148,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
case GGML_TYPE_F32:
return ampere_mma_available(cc);
case GGML_TYPE_F16:
return turing_mma_available(cc);
return turing_mma_available(cc) || amd_wmma_available(cc);
case GGML_TYPE_BF16:
return ampere_mma_available(cc);
return ampere_mma_available(cc) || amd_wmma_available(cc);
default:
return false;
}
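On the ggml_cuda_should_use_mmf change above: each case shows the old return line followed by its replacement (the hunk is 2 additions and 2 deletions). After this patch F16 routes to MMF on Turing-class NVIDIA GPUs or on RDNA4, BF16 on Ampere-class GPUs or on RDNA4, and F32 still requires Ampere MMA (the RDNA4 TF32 tile shapes exist only to keep the templates compiling). A small standalone C++ illustration of the resulting matrix, with stand-in predicates instead of the real helpers from common.cuh:

#include <cstdio>

enum example_type { F32, F16, BF16 };

// Mirrors the switch in ggml_cuda_should_use_mmf after this change; the three booleans
// stand in for ampere_mma_available / turing_mma_available / amd_wmma_available.
static bool should_use_mmf(example_type type, bool ampere_mma, bool turing_mma, bool amd_wmma) {
    switch (type) {
        case F32:  return ampere_mma;
        case F16:  return turing_mma || amd_wmma; // new: also true on RDNA4
        case BF16: return ampere_mma || amd_wmma; // new: also true on RDNA4
        default:   return false;
    }
}

int main() {
    const struct { const char * name; bool ampere, turing, wmma; } archs[] = {
        { "Turing",  false, true,  false },
        { "Ampere+", true,  true,  false },
        { "RDNA4",   false, false, true  },
    };
    const char * type_names[] = { "F32", "F16", "BF16" };
    for (const auto & arch : archs) {
        for (int t = 0; t < 3; ++t) {
            printf("%-8s %-4s -> %s\n", arch.name, type_names[t],
                   should_use_mmf((example_type) t, arch.ampere, arch.turing, arch.wmma) ? "MMF" : "other path");
        }
    }
    return 0;
}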
55 changes: 51 additions & 4 deletions ggml/src/ggml-cuda/mmf.cuh
@@ -27,10 +27,16 @@ static __global__ void mul_mat_f(
const int stride_col_id, const int stride_row_id,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#if defined(AMD_WMMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
typedef tile<16, 8, T> tile_A;
typedef tile<16, 8, T> tile_B;
typedef tile<16, 16, float> tile_C;
#else
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
#endif // defined(AMD_WMMA_AVAILABLE)

constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
@@ -151,11 +157,31 @@ static __global__ void mul_mat_f(

if constexpr (!has_ids) {
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
#if !defined(GGML_USE_HIP)
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
#else
if constexpr (std::is_same<T, half2>::value) {
tile_xy[j0*tile_k_padded + threadIdx.x] = __float22half2_rn(tmp);
} else if constexpr (std::is_same<T, nv_bfloat162>::value) {
tile_xy[j0*tile_k_padded + threadIdx.x] = __float22bfloat162_rn(tmp);
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
#endif // !defined(GGML_USE_HIP)
} else {
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
#if !defined(GGML_USE_HIP)
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
#else
if constexpr (std::is_same<T, half2>::value) {
tile_xy[j0*tile_k_padded + threadIdx.x] = __float22half2_rn(tmp);
} else if constexpr (std::is_same<T, nv_bfloat162>::value) {
tile_xy[j0*tile_k_padded + threadIdx.x] = __float22bfloat162_rn(tmp);
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
#endif // !defined(GGML_USE_HIP)
}
}
} else {
@@ -229,7 +255,7 @@ static __global__ void mul_mat_f(
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#endif // defined(AMD_WMMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}


@@ -244,10 +270,16 @@ static __global__ void mul_mat_f_ids(
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const uint3 sis1_fd, const uint3 nch_fd) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#if defined(AMD_WMMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
typedef tile<16, 8, T> tile_A;
typedef tile<16, 8, T> tile_B;
typedef tile<16, 16, float> tile_C;
#else
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
#endif // defined(AMD_WMMA_AVAILABLE)

constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
@@ -389,7 +421,17 @@ static __global__ void mul_mat_f_ids(
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const float2 tmp = vals_buf[curr_buf][j0];
#if !defined(GGML_USE_HIP)
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
#else
if constexpr (std::is_same<T, half2>::value) {
tile_xy[j0*tile_k_padded + threadIdx.x] = __float22half2_rn(tmp);
} else if constexpr (std::is_same<T, nv_bfloat162>::value) {
tile_xy[j0*tile_k_padded + threadIdx.x] = __float22bfloat162_rn(tmp);
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
#endif // !defined(GGML_USE_HIP)
}

if (itB + 1 < ntB) {
@@ -473,7 +515,7 @@ static __global__ void mul_mat_f_ids(
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#endif // defined(AMD_WMMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}

template<typename T, int cols_per_block, int nwarps>
Expand Down Expand Up @@ -533,8 +575,13 @@ void mul_mat_f_cuda(
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream, const mmf_ids_data * ids_data) {
#if defined(GGML_USE_HIP)
typedef tile<16, 8, T> tile_A;
typedef tile<16, 8, T> tile_B;
#else
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
#endif // defined(GGML_USE_HIP)

GGML_ASSERT(ncols_x % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
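Review note on the three copies of the float2-to-T conversion added in mul_mat_f and mul_mat_f_ids: they could be factored into one helper. The sketch below is only a suggestion and not part of the patch; the helper name is made up, but the intrinsics are the ones the patch already uses, and the static_assert limits it to the half2 / nv_bfloat162 instantiations that actually reach these code paths.

template <typename T>
static __device__ __forceinline__ T float2_to_tile_elem(const float2 v) {
#if !defined(GGML_USE_HIP)
    return {v.x, v.y};
#else
    static_assert(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>, "unsupported type");
    if constexpr (std::is_same_v<T, half2>) {
        return __float22half2_rn(v);
    } else {
        return __float22bfloat162_rn(v);
    }
#endif // !defined(GGML_USE_HIP)
}

// The call sites above would then collapse to a single line, e.g.:
// tile_xy[j0*tile_k_padded + threadIdx.x] = float2_to_tile_elem<T>(tmp);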