22#include " ggml.h"
33#include " softmax.cuh"
44#include < cstdint>
5- #include < utility>
65
76template <typename T>
87static __device__ __forceinline__ float t2f32 (T val) {
@@ -14,29 +13,6 @@ __device__ float __forceinline__ t2f32<half>(half val) {
1413 return __half2float (val);
1514}
1615
17- struct soft_max_params {
18-
19- int64_t nheads;
20- uint32_t n_head_log2;
21- int64_t ncols;
22- int64_t nrows_x;
23- int64_t nrows_y;
24- int64_t ne00;
25- int64_t ne01;
26- int64_t ne02;
27- int64_t ne03;
28- int64_t nb11;
29- int64_t nb12;
30- int64_t nb13;
31-
32- int64_t ne12;
33- int64_t ne13;
34- float scale;
35- float max_bias;
36- float m0;
37- float m1;
38- };
39-
4016// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
4117// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
4218#ifdef __clang__
@@ -45,33 +21,25 @@ struct soft_max_params {
4521#endif // __clang__
4622template <bool use_shared, int ncols_template, int block_size_template, typename T>
4723static __global__ void soft_max_f32 (
48- const float * x, const T * mask, float * dst, const soft_max_params p,
24+ const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
25+ const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2,
4926 float cap_params0, float cap_params1, bool do_softcap) {
50- const int ncols = ncols_template == 0 ? p. ncols : ncols_template;
27+ const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
5128
5229 const int tid = threadIdx .x ;
53-
54- const int64_t i03 = blockIdx .z ;
55- const int64_t i02 = blockIdx .y ;
56- const int64_t i01 = blockIdx .x ;
57-
58- // TODO: noncontigous inputs/outputs
59- const int rowx = blockIdx .x + blockIdx .y * gridDim .x + blockIdx .z * gridDim .x * gridDim .y ;
60-
61- const int64_t i11 = i01;
62- const int64_t i12 = i02 % p.ne12 ;
63- const int64_t i13 = i03 % p.ne13 ;
30+ const int rowx = blockIdx .x ;
31+ const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
6432
6533 x += int64_t (rowx)*ncols;
66- mask += (i11*p. nb11 + i12*p. nb12 + i13*p. nb13 ) / sizeof (T) * (mask != nullptr );
34+ mask += int64_t (rowy)*ncols * (mask != nullptr );
6735 dst += int64_t (rowx)*ncols;
6836
6937 const int block_size = block_size_template == 0 ? blockDim .x : block_size_template;
7038
7139 const int warp_id = threadIdx .x / WARP_SIZE;
7240 const int lane_id = threadIdx .x % WARP_SIZE;
7341
74- const float slope = get_alibi_slope (p. max_bias , i02, p. n_head_log2 , p. m0 , p. m1 );
42+ const float slope = get_alibi_slope (max_bias, rowx/nrows_y, n_head_log2, m0, m1);
7543
7644 extern __shared__ float data_soft_max_f32[];
7745 float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
@@ -89,16 +57,14 @@ static __global__ void soft_max_f32(
8957 }
9058
9159 const int64_t ix = (int64_t )rowx*ncols + col;
92- // const int64_t iy = (int64_t)rowy*ncols + col;
60+ const int64_t iy = (int64_t )rowy*ncols + col;
9361
9462 // const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
9563
9664 // const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
97-
98- // const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
9965
100- const float val = do_softcap ? p. scale *cap_params1*tanhf (cap_params0*x[ix]) + (mask ? slope*t2f32 (mask[col ]) : 0 .0f ) :
101- x[col]*p. scale + (mask ? slope*t2f32 (mask[col]) : 0 .0f );
66+ const float val = do_softcap ? scale*cap_params1*tanhf (cap_params0*x[ix]) + (mask ? slope*t2f32 (mask[iy ]) : 0 .0f ) :
67+ x[col]*scale + (mask ? slope*t2f32 (mask[col]) : 0 .0f );
10268
10369 vals[col] = val;
10470 max_val = max (max_val, val);
@@ -193,62 +159,64 @@ static __global__ void soft_max_back_f32(
193159 }
194160}
195161
196-
197- template <int ... Ns, typename T>
198- static void launch_soft_max_kernels (const float * x, const T * mask, float * dst,
199- const soft_max_params & p, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
200- {
201- const int id = ggml_cuda_get_device ();
202- const size_t smpbo = ggml_cuda_info ().devices [id].smpbo ;
203-
204- auto launch_kernel = [=](auto I) -> bool {
205- constexpr int ncols = decltype (I)::value;
206- constexpr int block = (ncols > 1024 ? 1024 : ncols);
207-
208- if (p.ncols == ncols) {
209- CUDA_SET_SHARED_MEMORY_LIMIT ((soft_max_f32<true , ncols, block, T>), smpbo);
210- soft_max_f32<true , ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
211- (x, mask, dst, p, cap_params0, cap_params1, do_softcap);
212- return true ;
213- }
214- return false ;
215- };
216-
217- // unary fold over launch_kernel
218- if ((launch_kernel (std::integral_constant<int , Ns>{}) || ...)) {
219- return ;
220- }
221-
222- // default case
223- CUDA_SET_SHARED_MEMORY_LIMIT ((soft_max_f32<true , 0 , 0 , T>), smpbo);
224- soft_max_f32<true , 0 , 0 >
225- <<<block_nums, block_dims, nbytes_shared, stream>>> (x, mask, dst, p, cap_params0, cap_params1, do_softcap);
226- }
227-
228-
// Host-side launcher for the soft_max_f32 kernel (optionally with softcapping).
//
// x:                input rows, nrows_x rows of ncols_x contiguous floats
// mask:             optional mask (half or float), broadcast over rows modulo nrows_y; may be nullptr
// dst:              output, same layout as x
// scale/max_bias:   softmax scale and ALiBi max bias (m0/m1 slopes derived below)
// cap_params0/1, do_softcap: softcap parameters, forwarded to the kernel
//
// Grid layout: one block per row; block size is the smallest power of two
// >= ncols_x, capped at CUDA_SOFT_MAX_BLOCK_SIZE.
template <typename T>
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) {
    int nth = WARP_SIZE;
    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
    const dim3 block_dims(nth,     1, 1);
    const dim3 block_nums(nrows_x, 1, 1);
    // shared memory: one float per column (padded to a whole warp) + WARP_SIZE floats for inter-warp reduction
    const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

    // ALiBi slope parameters; presumably nrows_x/nrows_y is the number of heads — TODO confirm against callers
    const uint32_t n_head      = nrows_x/nrows_y;
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    // All specialized launches take the same argument list; keep it in ONE place
    // so the nine launch sites below cannot drift out of sync.
#define GGML_CUDA_SOFT_MAX_LAUNCH(ncols, block)                                                   \
        soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>       \
            (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap)

    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
    if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
        // fast path: whole row cached in shared memory; dispatch to a fully
        // unrolled specialization when ncols_x is a known power of two
        switch (ncols_x) {
            case   32: GGML_CUDA_SOFT_MAX_LAUNCH(  32,   32); break;
            case   64: GGML_CUDA_SOFT_MAX_LAUNCH(  64,   64); break;
            case  128: GGML_CUDA_SOFT_MAX_LAUNCH( 128,  128); break;
            case  256: GGML_CUDA_SOFT_MAX_LAUNCH( 256,  256); break;
            case  512: GGML_CUDA_SOFT_MAX_LAUNCH( 512,  512); break;
            case 1024: GGML_CUDA_SOFT_MAX_LAUNCH(1024, 1024); break;
            case 2048: GGML_CUDA_SOFT_MAX_LAUNCH(2048, 1024); break;
            case 4096: GGML_CUDA_SOFT_MAX_LAUNCH(4096, 1024); break;
            default:   GGML_CUDA_SOFT_MAX_LAUNCH(   0,    0); break; // runtime-bounded generic kernel
        }
    } else {
        // row does not fit in shared memory: fall back to the non-cached kernel,
        // which only needs WARP_SIZE floats for the inter-warp reduction
        const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
    }
#undef GGML_CUDA_SOFT_MAX_LAUNCH
}
254222
@@ -276,11 +244,10 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
276244
277245 GGML_ASSERT (!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
278246
247+ const int64_t ne00 = src0->ne [0 ];
279248 const int64_t nrows_x = ggml_nrows (src0);
280249 const int64_t nrows_y = src0->ne [1 ];
281250
282- const int64_t ne00 = src0->ne [0 ];
283-
284251 float scale = 1 .0f ;
285252 float max_bias = 0 .0f ;
286253
@@ -289,54 +256,14 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
289256
290257 const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
291258
292- const int64_t nb11 = src1 ? src1->nb [1 ] : 1 ;
293- const int64_t nb12 = src1 ? src1->nb [2 ] : 1 ;
294- const int64_t nb13 = src1 ? src1->nb [3 ] : 1 ;
295-
296- const int64_t ne12 = src1 ? src1->ne [2 ] : 1 ;
297- const int64_t ne13 = src1 ? src1->ne [3 ] : 1 ;
298-
299- const uint32_t n_head = src0->ne [2 ];
300- const uint32_t n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head));
301-
302- const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
303- const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
304-
305-
306- soft_max_params params = {};
307- params.nheads = src0->ne [2 ];
308- params.n_head_log2 = n_head_log2;
309- params.ncols = ne00;
310- params.nrows_x = nrows_x;
311- params.nrows_y = nrows_y;
312- params.ne00 = src0->ne [0 ];
313- params.ne01 = src0->ne [1 ];
314- params.ne02 = src0->ne [2 ];
315- params.ne03 = src0->ne [3 ];
316- params.nb11 = nb11;
317- params.nb12 = nb12;
318- params.nb13 = nb13;
319- params.ne12 = ne12;
320- params.ne13 = ne13;
321- params.scale = scale;
322- params.max_bias = max_bias;
323- params.m0 = m0;
324- params.m1 = m1;
325-
326259 if (use_f16) {
327260 // const half * src1_dd = (const half *)src1_d;
328261
329- // soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
330-
331- soft_max_f32_cuda (src0_d, (const half *) src1_d, dst_d, params, 0 , 0 , false , stream);
332-
262+ soft_max_f32_cuda (src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0 , 0 , false , stream);
333263 } else {
334-
335264 // const float * src1_dd = (const float *)src1_d;
336265
337- // soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
338-
339- soft_max_f32_cuda (src0_d, (const float *) src1_d, dst_d, params, 0 , 0 , false , stream);
266+ soft_max_f32_cuda (src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0 , 0 , false , stream);
340267 }
341268}
342269
@@ -355,64 +282,24 @@ void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * ds
355282
356283 GGML_ASSERT (!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
357284
285+ const int64_t ne00 = src0->ne [0 ];
358286 const int64_t nrows_x = ggml_nrows (src0);
359287 const int64_t nrows_y = src0->ne [1 ];
360288
361- const int64_t ne00 = src0->ne [0 ];
362-
363- float scale = 1 .0f ;
364- float max_bias = 0 .0f ;
365-
366- memcpy (&scale, (const float *) dst->op_params + 0 , sizeof (float ));
367- memcpy (&max_bias, (const float *) dst->op_params + 1 , sizeof (float ));
368-
369- const int64_t nb11 = src1 ? src1->nb [1 ] : 1 ;
370- const int64_t nb12 = src1 ? src1->nb [2 ] : 1 ;
371- const int64_t nb13 = src1 ? src1->nb [3 ] : 1 ;
372-
373- const int64_t ne12 = src1 ? src1->ne [2 ] : 1 ;
374- const int64_t ne13 = src1 ? src1->ne [3 ] : 1 ;
375-
376- const uint32_t n_head = src0->ne [2 ];
377- const uint32_t n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head));
378-
379- const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
380- const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
381-
382- soft_max_params params = {};
383- params.nheads = src0->ne [2 ];
384- params.n_head_log2 = n_head_log2;
385- params.ncols = ne00;
386- params.nrows_x = nrows_x;
387- params.nrows_y = nrows_y;
388- params.ne00 = src0->ne [0 ];
389- params.ne01 = src0->ne [1 ];
390- params.ne02 = src0->ne [2 ];
391- params.ne03 = src0->ne [3 ];
392- params.nb11 = nb11;
393- params.nb12 = nb12;
394- params.nb13 = nb13;
395- params.ne12 = ne12;
396- params.ne13 = ne13;
397- params.scale = scale;
398- params.max_bias = max_bias;
399- params.m0 = m0;
400- params.m1 = m1;
401-
402- // float params[4];
403- // memcpy(params, dst->op_params, sizeof(params));
289+ float params[4 ];
290+ memcpy (params, dst->op_params , sizeof (params));
404291
405292 const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
406293 // printf("%s: %g, %g, %g, %g, %p, %d\n", __func__, params[0], params[1], params[2], params[3], (const void *)src1, use_f16);
407294
408295 if (use_f16) {
409296 const half * src1_dd = (const half *)src1_d;
410297
411- soft_max_f32_cuda (src0_d, src1_dd, dst_d, params, 0 , 0 , true , stream);
298+ soft_max_f32_cuda (src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[ 0 ], params[ 1 ], params[ 2 ], params[ 3 ] , true , stream);
412299 } else {
413300 const float * src1_dd = (const float *)src1_d;
414301
415- soft_max_f32_cuda (src0_d, src1_dd, dst_d, params, 0 , 0 , true , stream);
302+ soft_max_f32_cuda (src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[ 0 ], params[ 1 ], params[ 2 ], params[ 3 ] , true , stream);
416303 }
417304}
418305
0 commit comments