@@ -2,7 +2,6 @@
 #include "ggml.h"
 #include "softmax.cuh"
 #include <cstdint>
-#include <utility>
 
 template <typename T>
 static __device__ __forceinline__ float t2f32(T val) {
@@ -182,37 +181,6 @@ static __global__ void soft_max_back_f32(
     }
 }
 
-template <int... Ns, typename T>
-static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
-                                    const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
-{
-    const int id = ggml_cuda_get_device();
-    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
-
-    auto launch_kernel = [=](auto I) -> bool {
-        constexpr int ncols = decltype(I)::value;
-        constexpr int block = (ncols > 1024 ? 1024 : ncols);
-
-        if (p.ncols == ncols) {
-            CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
-            soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, mask, dst, p);
-            return true;
-        }
-        return false;
-    };
-
-    // unary fold over launch_kernel
-    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
-        return;
-    }
-
-    // default case
-    CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
-    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
-}
-
-
 template <typename T>
 static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
@@ -225,12 +193,46 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
 
-    const int id = ggml_cuda_get_device();
-    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
-
-
-    if (nbytes_shared <= smpbo) {
-        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
+    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
+    if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+            case 64:
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+            case 128:
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+            case 256:
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+            case 512:
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+            case 1024:
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+            case 2048:
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+            case 4096:
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+            default:
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+        }
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
         soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
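
A note on the launch_soft_max_kernels helper removed above: it dispatches a runtime column count onto a compile-time list of candidates using a C++17 unary fold over ||. The host-side sketch below isolates just that pattern; dispatch_ncols, run_fixed, and run_generic are illustrative names, not part of the patch. The first matching std::integral_constant short-circuits the fold, and the generic path runs only if no candidate matched.

    #include <cstdio>
    #include <type_traits>

    template <int ncols>
    static void run_fixed(int n) {
        // stands in for a soft_max_f32<true, ncols, block> launch with compile-time ncols
        std::printf("fixed path: ncols = %d (runtime n = %d)\n", ncols, n);
    }

    static void run_generic(int n) {
        // stands in for the soft_max_f32<true, 0, 0> fallback launch
        std::printf("generic path (runtime n = %d)\n", n);
    }

    template <int... Ns>
    static void dispatch_ncols(int n) {
        auto try_one = [=](auto I) -> bool {
            constexpr int ncols = decltype(I)::value;
            if (n == ncols) {
                run_fixed<ncols>(n);
                return true;
            }
            return false;
        };
        // unary fold over ||: candidates are tried left to right and the
        // fold short-circuits on the first match
        if ((try_one(std::integral_constant<int, Ns>{}) || ...)) {
            return;
        }
        run_generic(n); // no compile-time candidate matched
    }

    int main() {
        dispatch_ncols<32, 64, 128, 256, 512, 1024, 2048, 4096>(128); // fixed path
        dispatch_ncols<32, 64, 128, 256, 512, 1024, 2048, 4096>(100); // generic path
        return 0;
    }

The switch statement in the added code reaches the same set of template instantiations; the fold expression merely generates the dispatch from one list of Ns instead of spelling out each case.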
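On the FIXME in the added code ("this limit could be raised by ~2-4x on Ampere or newer"): the branch compares against the default per-block dynamic shared memory limit (smpb, 48 KiB on most devices), but Volta and newer parts allow a kernel to opt in to more via cudaFuncSetAttribute, up to the opt-in maximum that the removed code queried as smpbo. A minimal sketch of that opt-in, with my_softmax_kernel and opt_in_shared_memory as placeholder names:

    #include <cuda_runtime.h>

    __global__ void my_softmax_kernel(float * dst) {
        extern __shared__ float buf[]; // dynamic shared memory, sized at launch
        buf[threadIdx.x] = dst[threadIdx.x];
        __syncthreads();
        dst[threadIdx.x] = buf[blockDim.x - 1 - threadIdx.x]; // trivial use of the buffer
    }

    // Returns true if the kernel may be launched with nbytes of dynamic shared
    // memory; fails on pre-Volta devices or past the device's opt-in maximum.
    static bool opt_in_shared_memory(size_t nbytes) {
        cudaError_t err = cudaFuncSetAttribute(
            my_softmax_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int) nbytes);
        return err == cudaSuccess;
    }

After a successful opt-in, a launch such as my_softmax_kernel<<<blocks, threads, nbytes, stream>>>(dst) may use the larger buffer. The added code instead keeps the default limit and falls back to the nbytes_shared_low path whenever nbytes_shared exceeds smpb.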