Commit 9118bcf

Update softmax.cu
1 parent b73633d commit 9118bcf

1 file changed: +66 −179 lines changed

1 file changed

+66
-179
lines changed

ggml/src/ggml-cuda/softmax.cu

Lines changed: 66 additions & 179 deletions
@@ -2,7 +2,6 @@
 #include "ggml.h"
 #include "softmax.cuh"
 #include <cstdint>
-#include <utility>
 
 template <typename T>
 static __device__ __forceinline__ float t2f32(T val) {
@@ -14,29 +13,6 @@ __device__ float __forceinline__ t2f32<half>(half val) {
     return __half2float(val);
 }
 
-struct soft_max_params {
-
-    int64_t nheads;
-    uint32_t n_head_log2;
-    int64_t ncols;
-    int64_t nrows_x;
-    int64_t nrows_y;
-    int64_t ne00;
-    int64_t ne01;
-    int64_t ne02;
-    int64_t ne03;
-    int64_t nb11;
-    int64_t nb12;
-    int64_t nb13;
-
-    int64_t ne12;
-    int64_t ne13;
-    float scale;
-    float max_bias;
-    float m0;
-    float m1;
-};
-
 // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
 // As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
 #ifdef __clang__
@@ -45,33 +21,25 @@ struct soft_max_params {
 #endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
-        const float * x, const T * mask, float * dst, const soft_max_params p,
+        const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
+        const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2,
         float cap_params0, float cap_params1, bool do_softcap) {
-    const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
     const int tid = threadIdx.x;
-
-    const int64_t i03 = blockIdx.z;
-    const int64_t i02 = blockIdx.y;
-    const int64_t i01 = blockIdx.x;
-
-    //TODO: noncontigous inputs/outputs
-    const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
-
-    const int64_t i11 = i01;
-    const int64_t i12 = i02 % p.ne12;
-    const int64_t i13 = i03 % p.ne13;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
 
     x += int64_t(rowx)*ncols;
-    mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
+    mask += int64_t(rowy)*ncols * (mask != nullptr);
     dst += int64_t(rowx)*ncols;
 
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
+    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
 
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
@@ -89,16 +57,14 @@ static __global__ void soft_max_f32(
         }
 
         const int64_t ix = (int64_t)rowx*ncols + col;
-        // const int64_t iy = (int64_t)rowy*ncols + col;
+        const int64_t iy = (int64_t)rowy*ncols + col;
 
         // const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
 
         // const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
-
-        // const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
 
-        const float val = do_softcap ? p.scale*cap_params1*tanhf(cap_params0*x[ix]) + (mask ? slope*t2f32(mask[col]) : 0.0f) :
-                          x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
+        const float val = do_softcap ? scale*cap_params1*tanhf(cap_params0*x[ix]) + (mask ? slope*t2f32(mask[iy]) : 0.0f) :
+                          x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
@@ -193,62 +159,64 @@ static __global__ void soft_max_back_f32(
     }
 }
 
-
-template<int... Ns, typename T>
-static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
-        const soft_max_params & p, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
-{
-    const int id = ggml_cuda_get_device();
-    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
-
-    auto launch_kernel = [=](auto I) -> bool {
-        constexpr int ncols = decltype(I)::value;
-        constexpr int block = (ncols > 1024 ? 1024 : ncols);
-
-        if (p.ncols == ncols) {
-            CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
-            soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, mask, dst, p, cap_params0, cap_params1, do_softcap);
-            return true;
-        }
-        return false;
-    };
-
-    // unary fold over launch_kernel
-    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
-        return;
-    }
-
-    //default case
-    CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
-    soft_max_f32<true, 0, 0>
-        <<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p, cap_params0, cap_params1, do_softcap);
-}
-
-
 template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params,
-        float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) {
     int nth = WARP_SIZE;
-    const int64_t ncols_x = params.ncols;
-
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
-    const dim3 block_nums(params.ne01, params.ne02, params.ne03);
+    const dim3 block_nums(nrows_x, 1, 1);
     const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
-    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
-
-    const int id = ggml_cuda_get_device();
-    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+    const uint32_t n_head = nrows_x/nrows_y;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    if (nbytes_shared <= smpbo) {
-        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, cap_params0, cap_params1, do_softcap, stream, block_dims, block_nums, nbytes_shared);
+    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
+    if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
+                break;
+            case 64:
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
+                break;
+            case 128:
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
+                break;
+            case 256:
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
+                break;
+            case 512:
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
+                break;
+            case 1024:
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
+                break;
+            case 2048:
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
+                break;
+            case 4096:
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
+                break;
+            default:
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
+                break;
+        }
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(
-            x, mask, dst, params, cap_params0, cap_params1, do_softcap);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
     }
 }
 
@@ -276,11 +244,10 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
+    const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
     const int64_t nrows_y = src0->ne[1];
 
-    const int64_t ne00 = src0->ne[0];
-
     float scale = 1.0f;
     float max_bias = 0.0f;
 
@@ -289,54 +256,14 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
-    const int64_t nb11 = src1 ? src1->nb[1] : 1;
-    const int64_t nb12 = src1 ? src1->nb[2] : 1;
-    const int64_t nb13 = src1 ? src1->nb[3] : 1;
-
-    const int64_t ne12 = src1 ? src1->ne[2] : 1;
-    const int64_t ne13 = src1 ? src1->ne[3] : 1;
-
-    const uint32_t n_head = src0->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-
-    soft_max_params params = {};
-    params.nheads = src0->ne[2];
-    params.n_head_log2 = n_head_log2;
-    params.ncols = ne00;
-    params.nrows_x = nrows_x;
-    params.nrows_y = nrows_y;
-    params.ne00 = src0->ne[0];
-    params.ne01 = src0->ne[1];
-    params.ne02 = src0->ne[2];
-    params.ne03 = src0->ne[3];
-    params.nb11 = nb11;
-    params.nb12 = nb12;
-    params.nb13 = nb13;
-    params.ne12 = ne12;
-    params.ne13 = ne13;
-    params.scale = scale;
-    params.max_bias = max_bias;
-    params.m0 = m0;
-    params.m1 = m1;
-
     if (use_f16) {
         // const half * src1_dd = (const half *)src1_d;
 
-        // soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
-
-        soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, 0, 0, false, stream);
-
+        soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
     } else {
-
         // const float * src1_dd = (const float *)src1_d;
 
-        // soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
-
-        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, 0, 0, false, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
     }
 }
 
@@ -355,64 +282,24 @@ void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * ds
 
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
+    const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
     const int64_t nrows_y = src0->ne[1];
 
-    const int64_t ne00 = src0->ne[0];
-
-    float scale = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
-
-    const int64_t nb11 = src1 ? src1->nb[1] : 1;
-    const int64_t nb12 = src1 ? src1->nb[2] : 1;
-    const int64_t nb13 = src1 ? src1->nb[3] : 1;
-
-    const int64_t ne12 = src1 ? src1->ne[2] : 1;
-    const int64_t ne13 = src1 ? src1->ne[3] : 1;
-
-    const uint32_t n_head = src0->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    soft_max_params params = {};
-    params.nheads = src0->ne[2];
-    params.n_head_log2 = n_head_log2;
-    params.ncols = ne00;
-    params.nrows_x = nrows_x;
-    params.nrows_y = nrows_y;
-    params.ne00 = src0->ne[0];
-    params.ne01 = src0->ne[1];
-    params.ne02 = src0->ne[2];
-    params.ne03 = src0->ne[3];
-    params.nb11 = nb11;
-    params.nb12 = nb12;
-    params.nb13 = nb13;
-    params.ne12 = ne12;
-    params.ne13 = ne13;
-    params.scale = scale;
-    params.max_bias = max_bias;
-    params.m0 = m0;
-    params.m1 = m1;
-
-    // float params[4];
-    // memcpy(params, dst->op_params, sizeof(params));
+    float params[4];
+    memcpy(params, dst->op_params, sizeof(params));
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
     //printf("%s: %g, %g, %g, %g, %p, %d\n", __func__, params[0], params[1], params[2], params[3], (const void *)src1, use_f16);
 
     if (use_f16) {
         const half * src1_dd = (const half *)src1_d;
 
-        soft_max_f32_cuda(src0_d, src1_dd, dst_d, params, 0, 0, true, stream);
+        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
     } else {
         const float * src1_dd = (const float *)src1_d;
 
-        soft_max_f32_cuda(src0_d, src1_dd, dst_d, params, 0, 0, true, stream);
+        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
    }
 }
 
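The kernel above derives a per-head ALiBi slope from max_bias, m0, m1, and n_head_log2 through get_alibi_slope(), which is defined outside this file and so does not appear in the diff. Below is a minimal sketch of the conventional ggml-style slope computation those parameters feed, under the assumption that the usual convention applies; the helper name here is hypothetical and is not code from this commit.

// Sketch only (assumed convention): per-head ALiBi slope.
// h is the head index (rowx/nrows_y in the kernel); m0 and m1 are the two
// geometric bases derived from max_bias; n_head_log2 is the largest power of
// two that is <= n_head. The kernel itself calls get_alibi_slope(), defined
// elsewhere in the CUDA backend.
static __device__ __forceinline__ float alibi_slope_sketch(
        float max_bias, uint32_t h, uint32_t n_head_log2, float m0, float m1) {
    if (max_bias <= 0.0f) {
        return 1.0f; // ALiBi disabled: every head gets a neutral slope
    }
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
    return powf(base, exph);
}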