@@ -151,21 +151,37 @@ static __global__ void soft_max_back_f32(
 }
 }
 
-template <int... Ns>
-void increase_shared_mem_limits(std::size_t smpbo)
+template <int... Ns, typename T>
+static void launch_soft_max_kernels(int ncols_x, const float * x, const T * mask, float * dst,
+        int ncols_param, int nrows_y, float scale, float max_bias,
+        float m0, float m1, uint32_t n_head_log2, dim3 block_nums,
+        dim3 block_dims, size_t nbytes_shared, cudaStream_t stream)
 {
-    auto apply_limit = [smpbo](auto I) {
-        constexpr int ncols = decltype(I)::value;
-        constexpr int block = (ncols > 1024 ? 1024 : ncols);
-
-        CUDA_SET_SHARED_MEMORY_LIMIT(
-            (soft_max_f32<true, ncols, block, half>), smpbo);
-        CUDA_SET_SHARED_MEMORY_LIMIT(
-            (soft_max_f32<true, ncols, block, float>), smpbo);
+    const int    id    = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    auto launch_kernel = [=](auto I) -> bool {
+        constexpr int ncols = decltype(I)::value;
+        constexpr int block = (ncols > 1024 ? 1024 : ncols);
+
+        if (ncols_x == ncols) {
+            CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
+            soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, mask, dst, ncols_param, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+            return true;
+        }
+        return false;
     };
 
-    // unary fold
-    (apply_limit(std::integral_constant<int, Ns>{}), ...);
+    // unary fold over launch_kernel
+    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
+        return;
+    }
+
+    // default case
+    CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
+    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
+        (x, mask, dst, ncols_param, nrows_y, scale, max_bias, m0, m1, n_head_log2);
 }
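The load-bearing piece of this hunk is the short-circuiting fold expression: `launch_kernel` is invoked once per compile-time candidate in `Ns`, and `||` stops at the first candidate whose value matches the runtime `ncols_x`. A minimal standalone sketch of the same pattern, with no CUDA dependency (the names `dispatch` and `try_one` are illustrative, not from the patch):

```cpp
#include <cstdio>
#include <type_traits>

// Tries each compile-time candidate in Ns against the runtime value n.
template <int... Ns>
static void dispatch(int n) {
    auto try_one = [n](auto I) -> bool {
        // the candidate arrives as std::integral_constant, so its value
        // is usable as a constant expression (e.g. as a template argument)
        constexpr int cand = decltype(I)::value;
        if (n == cand) {
            std::printf("compile-time case %d\n", cand);
            return true;  // short-circuits the fold below
        }
        return false;
    };

    // unary fold over ||: candidates are tested left to right and
    // evaluation stops at the first try_one(...) that returns true
    if ((try_one(std::integral_constant<int, Ns>{}) || ...)) {
        return;
    }
    std::printf("fallback case for %d\n", n);  // generic/runtime path
}

int main() {
    dispatch<32, 64, 128>(64);   // -> compile-time case 64
    dispatch<32, 64, 128>(100);  // -> fallback case for 100
}
```

Because each candidate is a constant expression inside the lambda, it can be passed back in as a template argument; that is what lets the patch fold the old `switch (ncols_x)` below into one variadic call while still instantiating each `soft_max_f32<true, ncols, block>` specialization.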
@@ -189,47 +205,8 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     if (nbytes_shared <= smpbo) {
-
-        increase_shared_mem_limits<0, 32, 64, 128, 256, 512, 1024, 2048, 4096>(smpbo);
-
-        switch (ncols_x) {
-            case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-        }
+        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(
+            ncols_x, x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, nbytes_shared, stream);
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
         soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
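For context on the `nbytes_shared <= smpbo` guard: ggml's device info caches the per-block opt-in shared memory maximum, and `CUDA_SET_SHARED_MEMORY_LIMIT` raises a kernel's dynamic shared memory cap before launch. A hedged sketch of the plain CUDA runtime calls presumably behind this (the toy kernel is illustrative; the exact macro body may differ between ggml versions):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Toy kernel using dynamic shared memory, standing in for soft_max_f32.
__global__ void toy_kernel(float * out) {
    extern __shared__ float buf[];
    buf[threadIdx.x] = (float) threadIdx.x;
    __syncthreads();
    if (threadIdx.x == 0) {
        out[0] = buf[0];
    }
}

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);

    // sharedMemPerBlockOptin is the opt-in per-block maximum; this is the
    // quantity ggml stores as devices[id].smpbo
    const size_t smpbo = prop.sharedMemPerBlockOptin;
    std::printf("opt-in shared memory per block: %zu bytes\n", smpbo);

    // A kernel must opt in before it may be launched with more dynamic
    // shared memory than the default limit (48 KiB on most GPUs).
    cudaFuncSetAttribute(toy_kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         (int) smpbo);

    float * out = nullptr;
    cudaMalloc(&out, sizeof(float));
    toy_kernel<<<1, 256, smpbo>>>(out);  // launch with the raised limit
    cudaDeviceSynchronize();
    cudaFree(out);
}
```

If the required bytes exceed even the opt-in maximum, no attribute call can help, which is why the `else` branch above falls back to the `soft_max_f32<false, 0, 0>` path with only `WARP_SIZE*sizeof(float)` bytes of shared memory.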