-
Notifications
You must be signed in to change notification settings - Fork 13.7k
HIP: RDNA4 tensor core support for MMF #17077
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
2f7cfcf
d564a35
0ec241d
bbee5fe
6b8ceeb
fd18344
7a09e22
c65dd59
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -224,6 +224,10 @@ static const char * cu_get_error_str(CUresult err) { | |||||||||||||
| #define AMD_MFMA_AVAILABLE | ||||||||||||||
| #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) | ||||||||||||||
|
|
||||||||||||||
| #if defined(GGML_USE_HIP) && defined(RDNA4) | ||||||||||||||
| #define AMD_WMMA_AVAILABLE | ||||||||||||||
| #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) | ||||||||||||||
|
|
||||||||||||||
| // The Volta instructions are in principle available on Turing or newer but they are effectively unusable: | ||||||||||||||
| #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA | ||||||||||||||
| #define VOLTA_MMA_AVAILABLE | ||||||||||||||
|
|
@@ -283,6 +287,14 @@ static bool amd_mfma_available(const int cc) { | |||||||||||||
| #endif //!defined(GGML_HIP_NO_MMQ_MFMA) | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| static bool amd_wmma_available(const int cc) { | ||||||||||||||
| #if !defined(GGML_HIP_NO_WMMA) | ||||||||||||||
| return GGML_CUDA_CC_IS_RDNA4(cc); | ||||||||||||||
| #else | ||||||||||||||
| return false; | ||||||||||||||
| #endif //!defined(GGML_HIP_NO_WMMA) | ||||||||||||||
|
Comment on lines
+291
to
+295
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
As it is this is inconsistent with the check in device code. |
||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| static bool volta_mma_available(const int cc) { | ||||||||||||||
| return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -74,6 +74,33 @@ namespace ggml_cuda_mma { | |||||||||||||||||||||||||
| static constexpr int J = J_; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| #if defined(GGML_USE_HIP) | ||||||||||||||||||||||||||
| #if defined(RDNA4) | ||||||||||||||||||||||||||
| static constexpr int ne = I * J / 32; | ||||||||||||||||||||||||||
| T x[ne] = {0}; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| static constexpr __device__ bool supported() { | ||||||||||||||||||||||||||
| if (I == 16 && J == 16) return true; | ||||||||||||||||||||||||||
| return false; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| static __device__ __forceinline__ int get_i(const int l) { | ||||||||||||||||||||||||||
| if constexpr (I == 16 && J == 16) { | ||||||||||||||||||||||||||
| return 8 * (threadIdx.x / 16) + l; | ||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||
| NO_DEVICE_CODE; | ||||||||||||||||||||||||||
| return -1; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| static __device__ __forceinline__ int get_j(const int l) { | ||||||||||||||||||||||||||
| if constexpr (I == 16 && J == 16) { | ||||||||||||||||||||||||||
| return threadIdx.x % 16; | ||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||
| NO_DEVICE_CODE; | ||||||||||||||||||||||||||
| return -1; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| #else | ||||||||||||||||||||||||||
| static constexpr int ne = I * J / 64; | ||||||||||||||||||||||||||
| T x[ne] = {0}; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
@@ -119,6 +146,7 @@ namespace ggml_cuda_mma { | |||||||||||||||||||||||||
| return -1; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| #endif // defined(RDNA4) | ||||||||||||||||||||||||||
| #elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA | ||||||||||||||||||||||||||
| static constexpr int ne = I * J / 32; | ||||||||||||||||||||||||||
| T x[ne] = {0}; | ||||||||||||||||||||||||||
|
|
@@ -236,6 +264,32 @@ namespace ggml_cuda_mma { | |||||||||||||||||||||||||
| return -1; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| #elif defined(AMD_WMMA_AVAILABLE) | ||||||||||||||||||||||||||
| static constexpr int ne = I * J / 32; | ||||||||||||||||||||||||||
| half2 x[ne] = {{0.0f, 0.0f}}; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| static constexpr __device__ bool supported() { | ||||||||||||||||||||||||||
| if (I == 16 && J == 8) return true; | ||||||||||||||||||||||||||
| return false; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| static __device__ __forceinline__ int get_i(const int l) { | ||||||||||||||||||||||||||
| if constexpr (I == 16 && J == 8) { | ||||||||||||||||||||||||||
| return threadIdx.x % 16; | ||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||
| NO_DEVICE_CODE; | ||||||||||||||||||||||||||
| return -1; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| static __device__ __forceinline__ int get_j(const int l) { | ||||||||||||||||||||||||||
| if constexpr (I == 16 && J == 8) { | ||||||||||||||||||||||||||
| return 4 * (threadIdx.x / 16) + l; | ||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||
| NO_DEVICE_CODE; | ||||||||||||||||||||||||||
| return -1; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| #else | ||||||||||||||||||||||||||
| static constexpr int ne = I * J / WARP_SIZE; | ||||||||||||||||||||||||||
| half2 x[ne] = {{0.0f, 0.0f}}; | ||||||||||||||||||||||||||
|
|
@@ -285,6 +339,34 @@ namespace ggml_cuda_mma { | |||||||||||||||||||||||||
| struct tile<I_, J_, nv_bfloat162> { | ||||||||||||||||||||||||||
| static constexpr int I = I_; | ||||||||||||||||||||||||||
| static constexpr int J = J_; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| #if defined(AMD_WMMA_AVAILABLE) | ||||||||||||||||||||||||||
| static constexpr int ne = I * J / 32; | ||||||||||||||||||||||||||
| nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| static constexpr __device__ bool supported() { | ||||||||||||||||||||||||||
| if (I == 16 && J == 8) return true; | ||||||||||||||||||||||||||
| return false; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| static __device__ __forceinline__ int get_i(const int l) { | ||||||||||||||||||||||||||
| if constexpr (I == 16 && J == 8) { | ||||||||||||||||||||||||||
| return threadIdx.x % 16; | ||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||
| NO_DEVICE_CODE; | ||||||||||||||||||||||||||
| return -1; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| static __device__ __forceinline__ int get_j(const int l) { | ||||||||||||||||||||||||||
| if constexpr (I == 16 && J == 8) { | ||||||||||||||||||||||||||
| return 4 * (threadIdx.x / 16) + l; | ||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||
| NO_DEVICE_CODE; | ||||||||||||||||||||||||||
| return -1; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| #else | ||||||||||||||||||||||||||
| static constexpr int ne = I * J / WARP_SIZE; | ||||||||||||||||||||||||||
| nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
@@ -320,6 +402,7 @@ namespace ggml_cuda_mma { | |||||||||||||||||||||||||
| return -1; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| #endif // defined(AMD_WMMA_AVAILABLE) | ||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| template <int I, int J> | ||||||||||||||||||||||||||
|
|
@@ -353,6 +436,11 @@ namespace ggml_cuda_mma { | |||||||||||||||||||||||||
| const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); | ||||||||||||||||||||||||||
| xi[0] = xs[0]; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| #elif defined(AMD_WMMA_AVAILABLE) | ||||||||||||||||||||||||||
| constexpr int nbytes = sizeof(t.x); | ||||||||||||||||||||||||||
| // Special case for RDNA3 fp16 and bf16 wmma, the size is 32 bytes. | ||||||||||||||||||||||||||
| constexpr int alignment = nbytes > ggml_cuda_get_max_cpy_bytes() ? ggml_cuda_get_max_cpy_bytes() : nbytes; | ||||||||||||||||||||||||||
| ggml_cuda_memcpy_1<sizeof(t.x), alignment>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); | ||||||||||||||||||||||||||
|
Comment on lines
+440
to
+443
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That is not how |
||||||||||||||||||||||||||
| #else | ||||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||||
| for (int l = 0; l < t.ne; ++l) { | ||||||||||||||||||||||||||
|
|
@@ -639,12 +727,44 @@ namespace ggml_cuda_mma { | |||||||||||||||||||||||||
| : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) | ||||||||||||||||||||||||||
| : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); | ||||||||||||||||||||||||||
| #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE | ||||||||||||||||||||||||||
| #elif defined(AMD_WMMA_AVAILABLE) | ||||||||||||||||||||||||||
| using halfx8_t = __attribute__((ext_vector_type(8))) _Float16; | ||||||||||||||||||||||||||
| using floatx8_t = __attribute__((ext_vector_type(8))) float; | ||||||||||||||||||||||||||
| floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]); | ||||||||||||||||||||||||||
| const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]); | ||||||||||||||||||||||||||
| const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]); | ||||||||||||||||||||||||||
| acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag); | ||||||||||||||||||||||||||
| #else | ||||||||||||||||||||||||||
| GGML_UNUSED_VARS(D, A, B); | ||||||||||||||||||||||||||
| NO_DEVICE_CODE; | ||||||||||||||||||||||||||
| #endif // TURING_MMA_AVAILABLE | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| static __device__ __forceinline__ void mma( | ||||||||||||||||||||||||||
| tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) { | ||||||||||||||||||||||||||
| #ifdef AMPERE_MMA_AVAILABLE | ||||||||||||||||||||||||||
| const int * Axi = (const int *) A.x; | ||||||||||||||||||||||||||
| const int * Bxi = (const int *) B.x; | ||||||||||||||||||||||||||
| int * Dxi = (int *) D.x; | ||||||||||||||||||||||||||
| asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" | ||||||||||||||||||||||||||
| : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) | ||||||||||||||||||||||||||
| : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2])); | ||||||||||||||||||||||||||
| asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" | ||||||||||||||||||||||||||
| : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) | ||||||||||||||||||||||||||
| : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3])); | ||||||||||||||||||||||||||
| #elif defined(AMD_WMMA_AVAILABLE) | ||||||||||||||||||||||||||
|
Comment on lines
+745
to
+755
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Unless you tested this code and asserted that it actually works correctly for NVIDIA it should be removed. |
||||||||||||||||||||||||||
| using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16; | ||||||||||||||||||||||||||
| using floatx8_t = __attribute__((ext_vector_type(8))) float; | ||||||||||||||||||||||||||
| floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]); | ||||||||||||||||||||||||||
| const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]); | ||||||||||||||||||||||||||
| const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]); | ||||||||||||||||||||||||||
| acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag); | ||||||||||||||||||||||||||
| #else | ||||||||||||||||||||||||||
| GGML_UNUSED_VARS(D, A, B); | ||||||||||||||||||||||||||
| NO_DEVICE_CODE; | ||||||||||||||||||||||||||
| #endif // AMPERE_MMA_AVAILABLE | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| static __device__ __forceinline__ void mma( | ||||||||||||||||||||||||||
| tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) { | ||||||||||||||||||||||||||
| #if defined(AMD_MFMA_AVAILABLE) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |||||||
|
|
||||||||
| #include "mma.cuh" | ||||||||
| #include "common.cuh" | ||||||||
| #include "convert.cuh" | ||||||||
|
|
||||||||
| using namespace ggml_cuda_mma; | ||||||||
|
|
||||||||
|
|
@@ -27,20 +28,34 @@ static __global__ void mul_mat_f( | |||||||
| const int stride_col_id, const int stride_row_id, | ||||||||
| const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, | ||||||||
| const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { | ||||||||
| #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) | ||||||||
| #if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) | ||||||||
| #if defined(AMD_WMMA_AVAILABLE) | ||||||||
| // Special case for tf32, just dummy mma layout as wmma doesn't support it. | ||||||||
| constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16; | ||||||||
| constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16; | ||||||||
| typedef tile<16, 8, T> tile_A; | ||||||||
| typedef tile<tile_B_I, 8, T> tile_B; | ||||||||
| typedef tile<16, tile_C_J, float> tile_C; | ||||||||
|
|
||||||||
| constexpr bool a_supported = tile_A::supported(); | ||||||||
| constexpr bool b_supported = tile_B::supported(); | ||||||||
| constexpr bool c_supported = tile_C::supported(); | ||||||||
| constexpr bool supported = a_supported && b_supported && c_supported; | ||||||||
| #else | ||||||||
| constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); | ||||||||
| constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); | ||||||||
|
|
||||||||
| if (!I_16_supported && !I_32_supported) { | ||||||||
| NO_DEVICE_CODE; | ||||||||
| return; | ||||||||
| } | ||||||||
| constexpr bool supported = I_16_supported || I_32_supported; | ||||||||
|
|
||||||||
| constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. | ||||||||
|
|
||||||||
| typedef tile<I_preferred, 8, T> tile_A; | ||||||||
| typedef tile<8, 8, T> tile_B; | ||||||||
| typedef tile<I_preferred, 8, float> tile_C; | ||||||||
| #endif // defined(AMD_WMMA_AVAILABLE) | ||||||||
| if constexpr (!supported) { | ||||||||
| NO_DEVICE_CODE; | ||||||||
| return; | ||||||||
| } | ||||||||
|
|
||||||||
| constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||||||
| constexpr int tile_k_padded = warp_size + 4; | ||||||||
|
|
@@ -161,11 +176,19 @@ static __global__ void mul_mat_f( | |||||||
|
|
||||||||
| if constexpr (!has_ids) { | ||||||||
| const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); | ||||||||
| #if !defined(GGML_USE_HIP) | ||||||||
| tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; | ||||||||
| #else | ||||||||
| tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T, float2>(tmp); | ||||||||
| #endif // !defined(GGML_USE_HIP) | ||||||||
| } else { | ||||||||
| const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0; | ||||||||
| float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f); | ||||||||
| #if !defined(GGML_USE_HIP) | ||||||||
| tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; | ||||||||
| #else | ||||||||
| tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T, float2>(tmp); | ||||||||
| #endif // !defined(GGML_USE_HIP) | ||||||||
| } | ||||||||
| } | ||||||||
| } else { | ||||||||
|
|
@@ -239,7 +262,7 @@ static __global__ void mul_mat_f( | |||||||
| channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||||||
| sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||||||
| NO_DEVICE_CODE; | ||||||||
| #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) | ||||||||
| #endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) | ||||||||
| } | ||||||||
|
|
||||||||
| //This kernel is for larger batch sizes of mul_mat_id | ||||||||
|
|
@@ -253,20 +276,34 @@ static __global__ void mul_mat_f_ids( | |||||||
| const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, | ||||||||
| const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, | ||||||||
| const uint3 sis1_fd, const uint3 nch_fd) { | ||||||||
| #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) | ||||||||
| #if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) | ||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
To be clear: this is a TODO for me, from your side no action is necessary. |
||||||||
| #if defined(AMD_WMMA_AVAILABLE) | ||||||||
| // Special case for tf32, just dummy mma layout as wmma doesn't support it. | ||||||||
| constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16; | ||||||||
| constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16; | ||||||||
| typedef tile<16, 8, T> tile_A; | ||||||||
| typedef tile<tile_B_I, 8, T> tile_B; | ||||||||
| typedef tile<16, tile_C_J, float> tile_C; | ||||||||
|
|
||||||||
| constexpr bool a_supported = tile_A::supported(); | ||||||||
| constexpr bool b_supported = tile_B::supported(); | ||||||||
| constexpr bool c_supported = tile_C::supported(); | ||||||||
| constexpr bool supported = a_supported && b_supported && c_supported; | ||||||||
| #else | ||||||||
| constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported(); | ||||||||
| constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported(); | ||||||||
| constexpr bool supported = I_16_supported || I_32_supported; | ||||||||
|
|
||||||||
| if (!I_16_supported && !I_32_supported) { | ||||||||
| NO_DEVICE_CODE; | ||||||||
| return; | ||||||||
| } | ||||||||
|
|
||||||||
| constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster. | ||||||||
| constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster. | ||||||||
|
|
||||||||
| typedef tile<I_preferred, 8, T> tile_A; | ||||||||
| typedef tile<8, 8, T> tile_B; | ||||||||
| typedef tile<I_preferred, 8, float> tile_C; | ||||||||
| #endif // defined(AMD_WMMA_AVAILABLE) | ||||||||
| if constexpr (!supported) { | ||||||||
| NO_DEVICE_CODE; | ||||||||
| return; | ||||||||
| } | ||||||||
|
|
||||||||
| constexpr int warp_size = ggml_cuda_get_physical_warp_size(); | ||||||||
| constexpr int tile_k_padded = warp_size + 4; | ||||||||
|
|
@@ -408,7 +445,11 @@ static __global__ void mul_mat_f_ids( | |||||||
| #pragma unroll | ||||||||
| for (int j0 = 0; j0 < tile_B::I; ++j0) { | ||||||||
| const float2 tmp = vals_buf[curr_buf][j0]; | ||||||||
| #if !defined(GGML_USE_HIP) | ||||||||
| tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; | ||||||||
| #else | ||||||||
| tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T, float2>(tmp); | ||||||||
| #endif // !defined(GGML_USE_HIP) | ||||||||
|
Comment on lines
+448
to
+452
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use |
||||||||
| } | ||||||||
|
|
||||||||
| if (itB + 1 < ntB) { | ||||||||
|
|
@@ -492,7 +533,7 @@ static __global__ void mul_mat_f_ids( | |||||||
| channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||||||
| sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); | ||||||||
| NO_DEVICE_CODE; | ||||||||
| #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) | ||||||||
| #endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) | ||||||||
| } | ||||||||
|
|
||||||||
| template<typename T, int cols_per_block, int nwarps> | ||||||||
|
|
@@ -554,7 +595,8 @@ void mul_mat_f_cuda( | |||||||
| cudaStream_t stream, const mmf_ids_data * ids_data) { | ||||||||
| typedef tile<16, 8, T> tile_A_16; | ||||||||
| typedef tile<32, 8, T> tile_A_32; | ||||||||
| typedef tile< 8, 8, T> tile_B; | ||||||||
| typedef tile<16, 8, T> tile_B_16; | ||||||||
| typedef tile< 8, 8, T> tile_B_8; | ||||||||
|
|
||||||||
| GGML_ASSERT(ncols_x % 2 == 0); | ||||||||
| GGML_ASSERT(stride_row % 2 == 0); | ||||||||
|
|
@@ -581,7 +623,8 @@ void mul_mat_f_cuda( | |||||||
|
|
||||||||
| constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; | ||||||||
| const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4; | ||||||||
| const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; | ||||||||
| const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I; | ||||||||
| const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4; | ||||||||
| const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); | ||||||||
| const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; | ||||||||
| const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; | ||||||||
|
|
||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.