@@ -13,6 +13,29 @@ __device__ float __forceinline__ t2f32<half>(half val) {
     return __half2float(val);
 }
 
+struct soft_max_params {
+
+    int64_t nheads;
+    uint32_t n_head_log2;
+    int64_t ncols;
+    int64_t nrows_x;
+    int64_t nrows_y;
+    int64_t ne00;
+    int64_t ne01;
+    int64_t ne02;
+    int64_t ne03;
+    int64_t nb11;
+    int64_t nb12;
+    int64_t nb13;
+
+    int64_t ne12;
+    int64_t ne13;
+    float scale;
+    float max_bias;
+    float m0;
+    float m1;
+};
+
 // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
 // As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here.
 #ifdef __clang__
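Note: the new `soft_max_params` aggregate replaces ten loose scalar kernel arguments. A minimal standalone sketch (not part of the commit) of the property this relies on: CUDA copies kernel arguments by value, so the struct has to stay trivially copyable.

```cpp
// Sketch only: why a plain aggregate works as a kernel argument.
// The int64_t sizes/strides match ggml's tensor metadata types.
#include <cstdint>
#include <type_traits>

struct soft_max_params {
    int64_t  nheads;
    uint32_t n_head_log2;
    int64_t  ncols;
    // ... remaining ne/nb extents and strides as in the diff ...
    float    scale;
    float    max_bias;
    float    m0;
    float    m1;
};

// CUDA passes kernel parameters by value, so this must hold.
static_assert(std::is_trivially_copyable<soft_max_params>::value,
              "kernel parameters are passed by value");
```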
@@ -21,24 +44,32 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 #endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
-        const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
-        const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
-    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+        const float * x, const T * mask, float * dst, const soft_max_params p) {
+    const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
 
     const int tid = threadIdx.x;
-    const int rowx = blockIdx.x;
-    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
+
+    const int64_t i03 = blockIdx.z;
+    const int64_t i02 = blockIdx.y;
+    const int64_t i01 = blockIdx.x;
+
+    // TODO: non-contiguous inputs/outputs
+    const int rowx = blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y;
+
+    const int64_t i11 = i01;
+    const int64_t i12 = i02 % p.ne12;
+    const int64_t i13 = i03 % p.ne13;
 
     x   += int64_t(rowx)*ncols;
-    mask += int64_t(rowy)*ncols * (mask != nullptr);
+    mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
     dst += int64_t(rowx)*ncols;
 
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
 
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
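The substantive kernel change is in mask addressing: instead of `rowx % nrows_y`, the mask row is located through the byte strides `nb11`/`nb12`/`nb13`, with the modulo on `i12`/`i13` broadcasting a smaller mask over dims 2 and 3. A host-side sketch of the same index math, with illustrative names and shapes (not from the commit):

```cpp
// Host-side sketch of the kernel's mask addressing.
#include <cassert>
#include <cstdint>

// Byte offset of the mask row for source row (i01, i02, i03).
// nb11/nb12/nb13 are byte strides, as in ggml tensors.
int64_t mask_row_offset(int64_t i01, int64_t i02, int64_t i03,
                        int64_t ne12, int64_t ne13,
                        int64_t nb11, int64_t nb12, int64_t nb13) {
    const int64_t i11 = i01;        // rows map 1:1
    const int64_t i12 = i02 % ne12; // broadcast when the mask has fewer heads
    const int64_t i13 = i03 % ne13; // broadcast over the batch dim
    return i11*nb11 + i12*nb12 + i13*nb13;
}

int main() {
    // A contiguous f32 mask of shape [ncols=8, 4, ne12=1, ne13=1]:
    // every (i02, i03) falls back to the same 4 rows.
    const int64_t nb11 = 8*sizeof(float), nb12 = 4*nb11, nb13 = nb12;
    assert(mask_row_offset(3, 5, 7, 1, 1, nb11, nb12, nb13) == 3*nb11);
    return 0;
}
```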
@@ -60,7 +91,9 @@ static __global__ void soft_max_f32(
 
         // const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
 
-        const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
+        // const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
+
+        const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
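The per-element math is unchanged; only its inputs now come from `p`. For reference, a scalar sketch of what one row computes end to end (my own illustration, assuming the usual max-subtracted softmax that the rest of the kernel implements):

```cpp
// Scalar reference for one softmax row: pre-scale, add the
// slope-weighted mask, then a numerically stable softmax.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> softmax_row(const std::vector<float> & x,
                               const float * mask, // may be nullptr
                               float scale, float slope) {
    std::vector<float> vals(x.size());
    float max_val = -INFINITY;
    for (std::size_t i = 0; i < x.size(); ++i) {
        vals[i] = x[i]*scale + (mask ? slope*mask[i] : 0.0f);
        max_val = std::max(max_val, vals[i]);
    }
    float sum = 0.0f;
    for (float & v : vals) { v = std::exp(v - max_val); sum += v; }
    for (float & v : vals) { v /= sum; }
    return vals;
}
```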
@@ -156,63 +189,60 @@ static __global__ void soft_max_back_f32(
 }
 
 template <typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
+    const int64_t ncols_x = params.ncols;
+
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
-    const dim3 block_nums(nrows_x, 1, 1);
+    const dim3 block_nums(params.ne01, params.ne02, params.ne03);
     const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
-    const uint32_t n_head      = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     // FIXME: this limit could be raised by ~2-4x on Ampere or newer
     if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:
                 soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 64:
                 soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 128:
                 soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 256:
                 soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 512:
                 soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 1024:
                 soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 2048:
                 soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 4096:
                 soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             default:
                 soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
         }
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
     }
 }
 
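The launcher's grid is now `(ne01, ne02, ne03)` rather than a flat `nrows_x`; the block-size and shared-memory sizing is unchanged. A small sketch of that sizing logic with a worked value; `MY_WARP_SIZE`, `MY_BLOCK_MAX`, and `MY_PAD` are stand-ins for `WARP_SIZE`, `CUDA_SOFT_MAX_BLOCK_SIZE`, and `GGML_PAD`:

```cpp
// Sketch: thread count and dynamic shared memory for a given row width.
#include <cstdint>
#include <cstdio>

#define MY_WARP_SIZE 32   // assumption: matches WARP_SIZE
#define MY_BLOCK_MAX 1024 // assumption: matches CUDA_SOFT_MAX_BLOCK_SIZE
#define MY_PAD(x, n) (((x) + (n) - 1) / (n) * (n)) // same rounding as GGML_PAD

int main() {
    const int64_t ncols_x = 300;
    int nth = MY_WARP_SIZE;
    while (nth < ncols_x && nth < MY_BLOCK_MAX) nth *= 2;
    const size_t nbytes_shared =
        (MY_PAD(ncols_x, MY_WARP_SIZE) + MY_WARP_SIZE)*sizeof(float);
    // 300 columns -> 512 threads, (320 + 32) floats = 1408 bytes of smem
    printf("nth=%d shared=%zu\n", nth, nbytes_shared);
    return 0;
}
```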
@@ -240,10 +270,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
     const int64_t nrows_y = src0->ne[1];
 
+    const int64_t ne00 = src0->ne[0];
+
     float scale    = 1.0f;
     float max_bias = 0.0f;
 
@@ -252,14 +283,56 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
+
+    const uint32_t n_head      = src0->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+
+    soft_max_params params = {};
+    params.nheads      = src0->ne[2];
+    params.n_head_log2 = n_head_log2;
+    params.ncols       = ne00;
+    params.nrows_x     = nrows_x;
+    params.nrows_y     = nrows_y;
+    params.ne00        = src0->ne[0];
+    params.ne01        = src0->ne[1];
+    params.ne02        = src0->ne[2];
+    params.ne03        = src0->ne[3];
+    params.nb11        = nb11;
+    params.nb12        = nb12;
+    params.nb13        = nb13;
+    params.ne12        = ne12;
+    params.ne13        = ne13;
+    params.scale       = scale;
+    params.max_bias    = max_bias;
+    params.m0          = m0;
+    params.m1          = m1;
+
     if (use_f16) {
+
         // const half * src1_dd = (const half *)src1_d;
 
-        soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
-    } else {
+        // soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+
+        // } else {
+
         // const float * src1_dd = (const float *)src1_d;
 
-        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        // soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+
+        soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream);
+    } else {
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
+
     }
 }
 
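Host-side, the ALiBi constants are now derived from the head count `src0->ne[2]` instead of `nrows_x/nrows_y`. A sketch with concrete numbers; the slope-per-head formula in the comment is my reading of ggml's `get_alibi_slope`, not something this diff shows:

```cpp
// Sketch: the ALiBi constants as the host code derives them.
// For head h, the slope is m0^(h+1) while h < n_head_log2,
// and m1^(2*(h - n_head_log2) + 1) otherwise.
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_head   = 12;   // example head count (src0->ne[2])
    const float    max_bias = 8.0f; // example value from the op params
    const uint32_t n_head_log2 =
        1u << (uint32_t) floorf(log2f((float) n_head));                // = 8
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);    // 2^-1   = 0.5
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);    // 2^-0.5 ~ 0.7071
    printf("n_head_log2=%u m0=%f m1=%f\n", n_head_log2, m0, m1);
    return 0;
}
```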