
Commit 57aa09e

CUDA: conv2d support fp16 without wmma
1 parent 96db627 commit 57aa09e

File tree

1 file changed: +51 −2 lines

ggml/src/ggml-cuda/conv2d.cu

Lines changed: 51 additions & 2 deletions
```diff
@@ -1,8 +1,19 @@
 #include "conv2d.cuh"
 #include "convert.cuh"
 
-#include <mma.h>
-using namespace nvcuda;
+#ifdef FP16_MMA_AVAILABLE
+# if !defined(GGML_USE_HIP)
+#  include <mma.h>
+#  ifdef GGML_USE_MUSA
+namespace wmma = mtmusa::wmma;
+#  else // GGML_USE_MUSA
+namespace wmma = nvcuda::wmma;
+#  endif // GGML_USE_MUSA
+# elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
+#  include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
+# endif // !defined(GGML_USE_HIP)
+#endif // FP16_MMA_AVAILABLE
 
 struct conv_params {
     const int64_t IW, IH;
```
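The block above replaces the unconditional `#include <mma.h>` / `using namespace nvcuda;` with a guarded `wmma` namespace alias, so the tensor-core `half_mma` path can refer to `wmma::...` uniformly whether the backend is CUDA, MUSA (`mtmusa::wmma`), or HIP with rocWMMA. For reference, the sketch below shows the kind of WMMA usage such an alias enables; it is not taken from conv2d.cu, assumes 16x16x16 tiles (conv2d.cu defines its own WMMA_M/WMMA_N/WMMA_K), and needs sm_70 or newer to compile:

```cuda
#include <cuda_fp16.h>
#include <mma.h>

namespace wmma = nvcuda::wmma;

// One warp multiplies a 16x16 half tile pair into a 16x16 float tile.
__global__ void tile_mma_sketch(const half * A, const half * B, float * C) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float>              acc;

    wmma::fill_fragment(acc, 0.0f);
    wmma::load_matrix_sync(a_frag, A, 16);  // leading dimension = 16
    wmma::load_matrix_sync(b_frag, B, 16);
    wmma::mma_sync(acc, a_frag, b_frag, acc);
    wmma::store_matrix_sync(C, acc, 16, wmma::mem_row_major);
}
```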
```diff
@@ -111,6 +122,8 @@ class float_mma {
     __device__ __forceinline__ float * store_result() const { return buf; }
 };
 
+#if (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(FP16_MMA_AVAILABLE)))
+
 class half_mma {
   private:
     wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc;
@@ -136,6 +149,42 @@ class half_mma {
     }
 };
 
+#else
+
+class half_mma {
+  public:
+    float * buf;
+
+    __device__ __forceinline__ half_mma(float * scratch) {
+        buf               = scratch;
+        const int lane_id = threadIdx.x % warpSize;
+# pragma unroll
+        for (int i = lane_id; i < WMMA_M * WMMA_N; i += warpSize) {
+            buf[i] = 0.0f;
+        }
+    }
+
+    __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) {
+        const int lane_id = threadIdx.x % warpSize;
+# pragma unroll
+        for (int e = lane_id; e < (WMMA_M * WMMA_N); e += warpSize) {
+            int   m   = e / WMMA_N;
+            int   n   = e % WMMA_N;
+            float sum = buf[m * WMMA_N + n];
+# pragma unroll
+            for (int k = 0; k < WMMA_K; k++) {
+                float a = A_sh[m * strideA + k];
+                float b = B_sh[k * strideB + n];
+                sum     = fmaf(__half2float(a), __half2float(b), sum);
+            }
+            buf[m * WMMA_N + n] = sum;
+        }
+    }
+
+    __device__ __forceinline__ float * store_result() const { return buf; }
+};
+
+#endif // defined((__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE))
+
 template <typename T, typename layout, typename mma>
 static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT, const conv_params P) {
     extern __shared__ unsigned char smem_raw[];
```
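The new `#else` branch keeps the same `half_mma` interface (a constructor taking a float scratch buffer, `mma()`, and `store_result()`) but computes the tile product with per-lane `fmaf()` calls instead of tensor cores, which is what lets `conv2d_kernel` run fp16 kernels on GPUs where WMMA is unavailable. Below is a self-contained sketch of that fallback pattern for a single 16x16 tile; the kernel name `scalar_tile_mma` and the driver in `main` are illustrative, not part of the commit:

```cuda
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16

// One warp computes C = A * B for a single WMMA_M x WMMA_N tile with plain
// fmaf() instead of tensor cores, mirroring the lane-strided loop above.
__global__ void scalar_tile_mma(const half * A, const half * B, float * C) {
    const int lane_id = threadIdx.x % warpSize;
    for (int e = lane_id; e < WMMA_M * WMMA_N; e += warpSize) {
        const int m   = e / WMMA_N;  // output row owned by this lane
        const int n   = e % WMMA_N;  // output column owned by this lane
        float     sum = 0.0f;
        for (int k = 0; k < WMMA_K; k++) {
            sum = fmaf(__half2float(A[m * WMMA_K + k]),
                       __half2float(B[k * WMMA_N + n]), sum);
        }
        C[m * WMMA_N + n] = sum;
    }
}

int main() {
    half  *A, *B;
    float *C;
    cudaMallocManaged(&A, WMMA_M * WMMA_K * sizeof(half));
    cudaMallocManaged(&B, WMMA_K * WMMA_N * sizeof(half));
    cudaMallocManaged(&C, WMMA_M * WMMA_N * sizeof(float));
    for (int i = 0; i < WMMA_M * WMMA_K; i++) { A[i] = __float2half(1.0f); }
    for (int i = 0; i < WMMA_K * WMMA_N; i++) { B[i] = __float2half(2.0f); }
    scalar_tile_mma<<<1, 32>>>(A, B, C);  // a single warp
    cudaDeviceSynchronize();
    printf("C[0][0] = %.1f (expected %.1f)\n", C[0], 1.0f * 2.0f * WMMA_K);
    return 0;
}
```

On a 32-lane warp each lane owns 256/32 = 8 output elements of the tile, so the whole tile is covered without any cross-lane communication; the trade-off versus the WMMA path is purely throughput, not correctness.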
