Skip to content

Commit 8898040

Browse files
committed
CUDA: use fastdiv + ggml_cuda_mad for mmvf
1 parent 477a66b commit 8898040

File tree

1 file changed

+31
-28
lines changed

1 file changed

+31
-28
lines changed

ggml/src/ggml-cuda/mmvf.cu

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
#include "ggml.h"
23
#include "common.cuh"
34
#include "convert.cuh"
@@ -7,14 +8,14 @@ template <typename T, typename type_acc, int ncols_dst, int block_size>
78
static __global__ void mul_mat_vec_f(
89
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
910
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
10-
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
11-
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
11+
const uint3 channel_ratio_fd, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
12+
const uint3 sample_ratio_fd, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
1213
const int row = blockIdx.x;
1314
const int channel_dst = blockIdx.y;
14-
const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
15+
const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio_fd);
1516
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
1617
const int sample_dst = blockIdx.z;
17-
const int sample_x = sample_dst / sample_ratio;
18+
const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio_fd);
1819
const int sample_y = sample_dst;
1920
const int tid = threadIdx.x;
2021

@@ -47,8 +48,8 @@ static __global__ void mul_mat_vec_f(
4748
#pragma unroll
4849
for (int j = 0; j < ncols_dst; ++j) {
4950
const float2 tmpy = y2[j*stride_col_y2 + col2];
50-
sumf[j] += tmpx.x*tmpy.x;
51-
sumf[j] += tmpx.y*tmpy.y;
51+
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
52+
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
5253
}
5354
}
5455
} else if constexpr (std::is_same_v<T, half>) {
@@ -61,8 +62,8 @@ static __global__ void mul_mat_vec_f(
6162
#pragma unroll
6263
for (int j = 0; j < ncols_dst; ++j) {
6364
const float2 tmpy = y2[j*stride_col_y2 + col2];
64-
sumf[j] += tmpx.x * tmpy.x;
65-
sumf[j] += tmpx.y * tmpy.y;
65+
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
66+
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
6667
}
6768
}
6869
} else {
@@ -94,8 +95,10 @@ static __global__ void mul_mat_vec_f(
9495
#pragma unroll
9596
for (int j = 0; j < ncols_dst; ++j) {
9697
const float2 tmpy = y2[j*stride_col_y2 + col2];
97-
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
98-
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
98+
const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
99+
const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
100+
ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
101+
ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
99102
}
100103
}
101104
} else {
@@ -140,8 +143,8 @@ static void launch_mul_mat_vec_f_cuda(
140143
GGML_ASSERT(stride_col_y % 2 == 0);
141144
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
142145
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
143-
const int64_t channel_ratio = nchannels_dst / nchannels_x;
144-
const int64_t sample_ratio = nsamples_dst / nsamples_x;
146+
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
147+
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
145148

146149
const int device = ggml_cuda_get_device();
147150
const int warp_size = ggml_cuda_info().devices[device].warp_size;
@@ -167,50 +170,50 @@ static void launch_mul_mat_vec_f_cuda(
167170
case 32: {
168171
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
169172
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
170-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
171-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
173+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
174+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
172175
} break;
173176
case 64: {
174177
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
175178
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
176-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
177-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
179+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
180+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
178181
} break;
179182
case 96: {
180183
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
181184
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
182-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
183-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
185+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
186+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
184187
} break;
185188
case 128: {
186189
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
187190
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
188-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
189-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
191+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
192+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
190193
} break;
191194
case 160: {
192195
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
193196
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
194-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
195-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
197+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
198+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
196199
} break;
197200
case 192: {
198201
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
199202
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
200-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
201-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
203+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
204+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
202205
} break;
203206
case 224: {
204207
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
205208
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
206-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
207-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
209+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
210+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
208211
} break;
209212
case 256: {
210213
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
211214
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
212-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
213-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
215+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
216+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
214217
} break;
215218
default: {
216219
GGML_ABORT("fatal error");

0 commit comments

Comments
 (0)