Skip to content

musa: enable fp16 mma (all) and cublas on qy2 #13842

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,9 @@
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)

// Moore Threads
#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)

#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD

#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
Expand Down Expand Up @@ -203,9 +201,9 @@ typedef float2 dfloat2;
#define FAST_FP16_AVAILABLE
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
#define FP16_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)

#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
#define FP16_MMA_AVAILABLE
Expand All @@ -219,9 +217,9 @@ typedef float2 dfloat2;
#define CP_ASYNC_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
#define FLASH_ATTN_AVAILABLE
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)

static bool fp16_available(const int cc) {
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
Expand All @@ -233,7 +231,8 @@ static bool fast_fp16_available(const int cc) {

// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fast_fp16_hardware_available(const int cc) {
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc) ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
}

// Any FP16 tensor core instructions are available for ggml code.
Expand All @@ -242,7 +241,8 @@ static bool fp16_mma_available(const int cc) {
return false;
#else
if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) ||
GGML_CUDA_CC_IS_MTHREADS(cc)) {
return true;
} else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
Expand All @@ -259,7 +259,8 @@ static bool fp16_mma_available(const int cc) {
// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fp16_mma_hardware_available(const int cc) {
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
}

static bool bf16_mma_hardware_available(const int cc) {
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/fattn-wmma-f16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
#ifdef FP16_MMA_AVAILABLE
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#include <mma.h>
#ifdef GGML_USE_MUSA
namespace wmma = mtmusa::wmma;
#else // GGML_USE_MUSA
namespace wmma = nvcuda::wmma;
#endif // GGML_USE_MUSA
#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
#include <rocwmma/rocwmma.hpp>
Expand Down
25 changes: 15 additions & 10 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1227,9 +1227,12 @@ static void ggml_cuda_op_mul_mat_cublas(

const int cc = ggml_cuda_info().devices[id].cc;

const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);

const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;

if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
if (src1->type != GGML_TYPE_BF16) {
const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
Expand Down Expand Up @@ -1257,7 +1260,7 @@ static void ggml_cuda_op_mul_mat_cublas(

const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
} else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
} else if (fast_fp16_hardware_available(cc) && use_fp16) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
if (src0->type != GGML_TYPE_F16) {
Expand Down Expand Up @@ -3061,9 +3064,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return false;
}
#ifdef GGML_USE_MUSA
if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
!ggml_is_transposed(a) && !ggml_is_transposed(b)) {
return false;
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
if (GGML_CUDA_CC_IS_QY1(cc) && op->op == GGML_OP_MUL_MAT &&
a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) {
return false;
}
if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID &&
a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) {
return false;
}
}
#endif // GGML_USE_MUSA
switch (a->type) {
Expand All @@ -3090,11 +3100,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_BF16:
#ifdef GGML_USE_MUSA
if (a->type == GGML_TYPE_Q3_K) {
return false;
}
#endif // GGML_USE_MUSA
return true;
default:
return false;
Expand Down
4 changes: 2 additions & 2 deletions ggml/src/ggml-musa/mudnn.cuh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include "../include/ggml.h"
#include "../ggml-cuda/common.cuh"
#include "ggml-cuda/common.cuh"
#include "ggml.h"

// Asynchronously copies data from src tensor to dst tensor using the provided context.
// Returns a musaError_t indicating success or failure.
Expand Down
Loading