Commit 4c7bcaa

Review: refactor switch statement, change cross_entropy to use full size

1 parent f359216 · commit 4c7bcaa

3 files changed: +32 −55 lines

ggml/src/ggml-cuda/common.cuh

Lines changed: 1 addition & 1 deletion

@@ -187,7 +187,7 @@ static const char * cu_get_error_str(CUresult err) {
 } while (0)
 #else
 #define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0)
-#endif
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
 
 #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
 #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
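
Note beyond the diff itself: CUDA_SET_SHARED_MEMORY_LIMIT is the macro whose no-op fallback carries the clarified #endif comment above. The sketch below is illustrative only (the real definition sits earlier in common.cuh and may cache state per device or use ggml's error-checking macros); it shows the kind of call such a macro typically wraps, cudaFuncSetAttribute raising a kernel's dynamic shared memory cap.

// Illustrative sketch only; not the actual CUDA_SET_SHARED_MEMORY_LIMIT definition.
#include <cuda_runtime.h>
#include <cstdio>

#define SET_SHARED_MEMORY_LIMIT_SKETCH(kernel, nbytes)                                       \
    do {                                                                                      \
        cudaError_t err_ = cudaFuncSetAttribute(                                              \
            (kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, (int) (nbytes));           \
        if (err_ != cudaSuccess) {                                                            \
            fprintf(stderr, "cudaFuncSetAttribute failed: %s\n", cudaGetErrorString(err_));   \
        }                                                                                     \
    } while (0)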

ggml/src/ggml-cuda/cross-entropy-loss.cu

Lines changed: 1 addition & 1 deletion

@@ -169,7 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
     const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
 
     if (nbytes_shared <= smpbo) {
-        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), nbytes_shared);
+        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
         cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
     } else {
         cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
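
Note on the change above: the shared memory limit for cross_entropy_loss_back_f32<true> is now raised to the device's opt-in maximum (smpbo) instead of the current launch's nbytes_shared, which is what the commit message means by "use full size". Below is a minimal, hedged sketch of the same idea using plain CUDA runtime calls; the helper name is hypothetical and not part of ggml.

#include <cuda_runtime.h>

// Hypothetical helper: raise a kernel's dynamic shared memory cap to the device's
// opt-in per-block maximum once, instead of to each launch's exact requirement.
static void raise_to_device_optin_limit(const void * kernel, int device) {
    int smpbo = 0;  // roughly what ggml_cuda_info().devices[id].smpbo holds
    cudaDeviceGetAttribute(&smpbo, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    // With the cap at the device maximum, later launches that request more
    // dynamic shared memory than this call's nbytes_shared need no further setup.
    cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo);
}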

ggml/src/ggml-cuda/softmax.cu

Lines changed: 30 additions & 53 deletions

@@ -151,21 +151,37 @@ static __global__ void soft_max_back_f32(
     }
 }
 
-template<int... Ns>
-void increase_shared_mem_limits(std::size_t smpbo)
+template<int... Ns, typename T>
+static void launch_soft_max_kernels(int ncols_x, const float * x, const T * mask, float * dst,
+                                    int ncols_param, int nrows_y, float scale, float max_bias,
+                                    float m0, float m1, uint32_t n_head_log2, dim3 block_nums,
+                                    dim3 block_dims, size_t nbytes_shared, cudaStream_t stream)
 {
-    auto apply_limit = [smpbo](auto I) {
-        constexpr int ncols = decltype(I)::value;
-        constexpr int block = (ncols > 1024 ? 1024 : ncols);
-
-        CUDA_SET_SHARED_MEMORY_LIMIT(
-            (soft_max_f32<true, ncols, block, half >), smpbo);
-        CUDA_SET_SHARED_MEMORY_LIMIT(
-            (soft_max_f32<true, ncols, block, float>), smpbo);
+    const int id = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    auto launch_kernel = [=](auto I) -> bool {
+        constexpr int ncols = decltype(I)::value;
+        constexpr int block = (ncols > 1024 ? 1024 : ncols);
+
+        if (ncols_x == ncols) {
+            CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
+            soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, mask, dst, ncols_param, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+            return true;
+        }
+        return false;
     };
 
-    //unary fold
-    ( apply_limit(std::integral_constant<int, Ns>{}), ... );
+    // unary fold over launch_kernel
+    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
+        return;
+    }
+
+    // default case
+    CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
+    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
+        (x, mask, dst, ncols_param, nrows_y, scale, max_bias, m0, m1, n_head_log2);
 }
 
 
@@ -189,47 +205,8 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
 
 
     if (nbytes_shared <= smpbo) {
-
-        increase_shared_mem_limits<0, 32, 64, 128, 256, 512, 1024, 2048, 4096>(smpbo);
-
-        switch (ncols_x) {
-            case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-        }
+        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(
+            ncols_x, x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, nbytes_shared, stream);
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
         soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
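
Note on the refactor above: the runtime switch over ncols_x is folded into launch_soft_max_kernels, which tries each compile-time column count in turn; the || fold short-circuits at the first match, and the generic soft_max_f32<true, 0, 0> path handles everything else. Below is a standalone, hedged sketch of that dispatch pattern, with the kernel launch replaced by a print so it compiles as plain C++17 without CUDA.

#include <cstdio>
#include <type_traits>

// Stand-in for the specialized kernel launch in this sketch.
template <int ncols, int block>
static void fake_launch(int n) {
    std::printf("specialized: ncols=%d block=%d (n=%d)\n", ncols, block, n);
}

template <int... Ns>
static void dispatch(int n) {
    auto try_one = [&](auto I) -> bool {
        constexpr int ncols = decltype(I)::value;
        constexpr int block = (ncols > 1024 ? 1024 : ncols);
        if (n == ncols) {
            fake_launch<ncols, block>(n);
            return true;   // match found: stop folding
        }
        return false;      // try the next candidate
    };

    // || fold over the parameter pack: candidates are tried left to right and
    // the fold short-circuits at the first lambda call that returns true.
    if ((try_one(std::integral_constant<int, Ns>{}) || ...)) {
        return;
    }
    std::printf("generic fallback (n=%d)\n", n);  // mirrors the <true, 0, 0> default
}

int main() {
    dispatch<32, 64, 128, 256, 512, 1024, 2048, 4096>(256);  // hits the 256 specialization
    dispatch<32, 64, 128, 256, 512, 1024, 2048, 4096>(777);  // takes the fallback
    return 0;
}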
