1+
12#include " ggml.h"
23#include " common.cuh"
34#include " convert.cuh"
@@ -7,14 +8,14 @@ template <typename T, typename type_acc, int ncols_dst, int block_size>
78static __global__ void mul_mat_vec_f (
89 const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
910 const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
10- const int channel_ratio , const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
11- const int sample_ratio , const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
11+ const uint3 channel_ratio_fd , const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
12+ const uint3 sample_ratio_fd , const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
1213 const int row = blockIdx .x ;
1314 const int channel_dst = blockIdx .y ;
14- const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio ;
15+ const int channel_x = ids ? ids[channel_dst] : fastdiv (( uint32_t ) channel_dst, channel_ratio_fd) ;
1516 const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
1617 const int sample_dst = blockIdx .z ;
17- const int sample_x = sample_dst / sample_ratio ;
18+ const int sample_x = fastdiv (( uint32_t ) sample_dst, sample_ratio_fd) ;
1819 const int sample_y = sample_dst;
1920 const int tid = threadIdx .x ;
2021
@@ -47,8 +48,8 @@ static __global__ void mul_mat_vec_f(
4748#pragma unroll
4849 for (int j = 0 ; j < ncols_dst; ++j) {
4950 const float2 tmpy = y2[j*stride_col_y2 + col2];
50- sumf[j] += tmpx.x * tmpy.x ;
51- sumf[j] += tmpx.y * tmpy.y ;
51+ ggml_cuda_mad ( sumf[j], tmpx.x , tmpy.x ) ;
52+ ggml_cuda_mad ( sumf[j], tmpx.y , tmpy.y ) ;
5253 }
5354 }
5455 } else if constexpr (std::is_same_v<T, half>) {
@@ -61,8 +62,8 @@ static __global__ void mul_mat_vec_f(
6162#pragma unroll
6263 for (int j = 0 ; j < ncols_dst; ++j) {
6364 const float2 tmpy = y2[j*stride_col_y2 + col2];
64- sumf[j] += tmpx.x * tmpy.x ;
65- sumf[j] += tmpx.y * tmpy.y ;
65+ ggml_cuda_mad ( sumf[j], tmpx.x , tmpy.x ) ;
66+ ggml_cuda_mad ( sumf[j], tmpx.y , tmpy.y ) ;
6667 }
6768 }
6869 } else {
@@ -94,8 +95,10 @@ static __global__ void mul_mat_vec_f(
9495#pragma unroll
9596 for (int j = 0 ; j < ncols_dst; ++j) {
9697 const float2 tmpy = y2[j*stride_col_y2 + col2];
97- sumf[j] += ggml_cuda_cast<float >(reinterpret_cast <const nv_bfloat16 *>(&tmpx)[0 ]) * tmpy.x ;
98- sumf[j] += ggml_cuda_cast<float >(reinterpret_cast <const nv_bfloat16 *>(&tmpx)[1 ]) * tmpy.y ;
98+ const float tmpx0 = ggml_cuda_cast<float >(reinterpret_cast <const nv_bfloat16 *>(&tmpx)[0 ]);
99+ const float tmpx1 = ggml_cuda_cast<float >(reinterpret_cast <const nv_bfloat16 *>(&tmpx)[1 ]);
100+ ggml_cuda_mad (sumf[j], tmpx0, tmpy.x );
101+ ggml_cuda_mad (sumf[j], tmpx1, tmpy.y );
99102 }
100103 }
101104 } else {
@@ -140,8 +143,8 @@ static void launch_mul_mat_vec_f_cuda(
140143 GGML_ASSERT (stride_col_y % 2 == 0 );
141144 GGML_ASSERT (ids || nchannels_dst % nchannels_x == 0 );
142145 GGML_ASSERT ( nsamples_dst % nsamples_x == 0 );
143- const int64_t channel_ratio = nchannels_dst / nchannels_x;
144- const int64_t sample_ratio = nsamples_dst / nsamples_x;
146+ const uint3 channel_ratio_fd = ids ? make_uint3 ( 0 , 0 , 0 ) : init_fastdiv_values ( nchannels_dst / nchannels_x) ;
147+ const uint3 sample_ratio_fd = init_fastdiv_values ( nsamples_dst / nsamples_x) ;
145148
146149 const int device = ggml_cuda_get_device ();
147150 const int warp_size = ggml_cuda_info ().devices [device].warp_size ;
@@ -167,50 +170,50 @@ static void launch_mul_mat_vec_f_cuda(
167170 case 32 : {
168171 mul_mat_vec_f<T, type_acc, ncols_dst, 32 ><<<block_nums, block_dims, nbytes_shared, stream>>>
169172 (x, y, ids, dst, ncols/2 , nchannels_y, stride_row, stride_col_y/2 , stride_col_dst,
170- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
171- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
173+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
174+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
172175 } break ;
173176 case 64 : {
174177 mul_mat_vec_f<T, type_acc, ncols_dst, 64 ><<<block_nums, block_dims, nbytes_shared, stream>>>
175178 (x, y, ids, dst, ncols/2 , nchannels_y, stride_row, stride_col_y/2 , stride_col_dst,
176- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
177- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
179+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
180+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
178181 } break ;
179182 case 96 : {
180183 mul_mat_vec_f<T, type_acc, ncols_dst, 96 ><<<block_nums, block_dims, nbytes_shared, stream>>>
181184 (x, y, ids, dst, ncols/2 , nchannels_y, stride_row, stride_col_y/2 , stride_col_dst,
182- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
183- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
185+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
186+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
184187 } break ;
185188 case 128 : {
186189 mul_mat_vec_f<T, type_acc, ncols_dst, 128 ><<<block_nums, block_dims, nbytes_shared, stream>>>
187190 (x, y, ids, dst, ncols/2 , nchannels_y, stride_row, stride_col_y/2 , stride_col_dst,
188- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
189- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
191+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
192+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
190193 } break ;
191194 case 160 : {
192195 mul_mat_vec_f<T, type_acc, ncols_dst, 160 ><<<block_nums, block_dims, nbytes_shared, stream>>>
193196 (x, y, ids, dst, ncols/2 , nchannels_y, stride_row, stride_col_y/2 , stride_col_dst,
194- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
195- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
197+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
198+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
196199 } break ;
197200 case 192 : {
198201 mul_mat_vec_f<T, type_acc, ncols_dst, 192 ><<<block_nums, block_dims, nbytes_shared, stream>>>
199202 (x, y, ids, dst, ncols/2 , nchannels_y, stride_row, stride_col_y/2 , stride_col_dst,
200- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
201- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
203+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
204+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
202205 } break ;
203206 case 224 : {
204207 mul_mat_vec_f<T, type_acc, ncols_dst, 224 ><<<block_nums, block_dims, nbytes_shared, stream>>>
205208 (x, y, ids, dst, ncols/2 , nchannels_y, stride_row, stride_col_y/2 , stride_col_dst,
206- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
207- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
209+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
210+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
208211 } break ;
209212 case 256 : {
210213 mul_mat_vec_f<T, type_acc, ncols_dst, 256 ><<<block_nums, block_dims, nbytes_shared, stream>>>
211214 (x, y, ids, dst, ncols/2 , nchannels_y, stride_row, stride_col_y/2 , stride_col_dst,
212- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
213- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
215+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
216+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
214217 } break ;
215218 default : {
216219 GGML_ABORT (" fatal error" );
0 commit comments