Skip to content

Commit f359216

Browse files
committed
CUDA: add dynamic shared mem to softmax, refactor general usage
1 parent de56944 commit f359216

File tree

5 files changed

+46
-24
lines changed

5 files changed

+46
-24
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,20 @@ static const char * cu_get_error_str(CUresult err) {
175175
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
176176
#endif
177177

178+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
179+
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
180+
do { \
181+
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; \
182+
const int id = ggml_cuda_get_device(); \
183+
if (!shared_memory_limit_raised[id]) { \
184+
CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
185+
shared_memory_limit_raised[id] = true; \
186+
} \
187+
} while (0)
188+
#else
189+
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0)
190+
#endif
191+
178192
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
179193
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
180194
#else

ggml/src/ggml-cuda/cross-entropy-loss.cu

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
123123
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
124124

125125
if (nbytes_shared <= smpbo) {
126-
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
127-
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
128-
if (!shared_memory_limit_raised[id]) {
129-
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
130-
shared_memory_limit_raised[id] = true;
131-
}
132-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
126+
CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), nbytes_shared);
133127
cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
134128
} else {
135129
cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
@@ -175,13 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
175169
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
176170

177171
if (nbytes_shared <= smpbo) {
178-
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
179-
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
180-
if (!shared_memory_limit_raised[id]) {
181-
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
182-
shared_memory_limit_raised[id] = true;
183-
}
184-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
172+
CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), nbytes_shared);
185173
cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
186174
} else {
187175
cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3016,14 +3016,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
30163016

30173017
const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
30183018

3019-
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
3020-
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
3021-
if (!shared_memory_limit_raised[id]) {
3022-
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
3023-
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
3024-
shared_memory_limit_raised[id] = true;
3025-
}
3026-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
3019+
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, false>), nbytes_shared);
3020+
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, true>), nbytes_shared);
30273021

30283022
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
30293023
const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;

ggml/src/ggml-cuda/softmax.cu

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "ggml.h"
33
#include "softmax.cuh"
44
#include <cstdint>
5+
#include <utility>
56

67
template <typename T>
78
static __device__ __forceinline__ float t2f32(T val) {
@@ -150,6 +151,24 @@ static __global__ void soft_max_back_f32(
150151
}
151152
}
152153

154+
template<int... Ns>
155+
void increase_shared_mem_limits(std::size_t smpbo)
156+
{
157+
auto apply_limit = [smpbo](auto I) {
158+
constexpr int ncols = decltype(I)::value;
159+
constexpr int block = (ncols > 1024 ? 1024 : ncols);
160+
161+
CUDA_SET_SHARED_MEMORY_LIMIT(
162+
(soft_max_f32<true, ncols, block, half >), smpbo);
163+
CUDA_SET_SHARED_MEMORY_LIMIT(
164+
(soft_max_f32<true, ncols, block, float>), smpbo);
165+
};
166+
167+
//unary fold
168+
( apply_limit(std::integral_constant<int, Ns>{}), ... );
169+
}
170+
171+
153172
template<typename T>
154173
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
155174
int nth = WARP_SIZE;
@@ -165,8 +184,14 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
165184
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
166185
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
167186

168-
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
169-
if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
187+
const int id = ggml_cuda_get_device();
188+
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
189+
190+
191+
if (nbytes_shared <= smpbo) {
192+
193+
increase_shared_mem_limits<0, 32, 64, 128, 256, 512, 1024, 2048, 4096>(smpbo);
194+
170195
switch (ncols_x) {
171196
case 32:
172197
soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4891,6 +4891,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
48914891
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
48924892

48934893
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
4894+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 1024, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
48944895
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
48954896
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
48964897
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));

0 commit comments

Comments
 (0)