#include "conv2d.cuh"
#include "convert.cuh"

-#include <mma.h>
-using namespace nvcuda;
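+// Pull in the vendor WMMA header and alias its namespace so the code below can use wmma:: uniformly on CUDA, MUSA and rocWMMA builds.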
+#ifdef FP16_MMA_AVAILABLE
+#    if !defined(GGML_USE_HIP)
+#        include <mma.h>
+#        ifdef GGML_USE_MUSA
+namespace wmma = mtmusa::wmma;
+#        else    // GGML_USE_MUSA
+namespace wmma = nvcuda::wmma;
+#        endif   // GGML_USE_MUSA
+#    elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
+#        include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
+#    endif   // !defined(GGML_USE_HIP)
+#endif       // FP16_MMA_AVAILABLE

struct conv_params {
    const int64_t IW, IH;
@@ -111,6 +122,8 @@ class float_mma {
    __device__ __forceinline__ float * store_result() const { return buf; }
};

+#if (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE))
+
class half_mma {
  private:
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc;
@@ -136,6 +149,42 @@ class half_mma {
    }
};

+#else
+
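+// Scalar fallback for targets without WMMA support: the warp emulates one WMMA_M x WMMA_N accumulator tile
+// in a float scratch buffer, with the lanes dividing the tile elements between them.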
+class half_mma {
+  public:
+    float * buf;
+
+    __device__ __forceinline__ half_mma(float * scratch) {
+        buf               = scratch;
+        const int lane_id = threadIdx.x % warpSize;
+#    pragma unroll
+        for (int i = lane_id; i < WMMA_M * WMMA_N; i += warpSize) {
+            buf[i] = 0.0f;
+        }
+    }
+
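+    // Accumulate a WMMA_M x WMMA_K by WMMA_K x WMMA_N product from shared memory into the tile, one scalar FMA at a time.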
+    __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) {
+        const int lane_id = threadIdx.x % warpSize;
+#    pragma unroll
+        for (int e = lane_id; e < (WMMA_M * WMMA_N); e += warpSize) {
+            int   m   = e / WMMA_N;
+            int   n   = e % WMMA_N;
+            float sum = buf[m * WMMA_N + n];
+#    pragma unroll
+            for (int k = 0; k < WMMA_K; k++) {
+                const float a = __half2float(A_sh[m * strideA + k]);
+                const float b = __half2float(B_sh[k * strideB + n]);
+                sum           = fmaf(a, b, sum);
+            }
+            buf[m * WMMA_N + n] = sum;
+        }
+    }
+
+    __device__ __forceinline__ float * store_result() const { return buf; }
+};
+#endif  // (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE))
+
template <typename T, typename layout, typename mma>
static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT, const conv_params P) {
    extern __shared__ unsigned char smem_raw[];