Skip to content

Commit 2e24304

Browse files
authored
[SYCL][CUDA] Add env variable to specify max local mem size (#6173)
Previously the max local mem allocation in DPC++ for CUDA backend was 48KB for most devices. However as https://docs.nvidia.com/cuda/ampere-tuning-guide/index.html notes, for the A100 the max local mem dynamic allocation is in fact 164KB. This PR introduces an environment variable SYCL_PI_CUDA_MAX_LOCAL_MEM_SZ which allows you to manually specify the max local memory in bytes allowed to be allocated per kernel for a given application. If an invalid value is specified (one that exceeds the device's capabilities/is negative) then a runtime error will be thrown. Using: SYCL_PI_CUDA_MAX_LOCAL_MEM_SZ=166912 ./a.out Allows the application to use up to 163KB of local memory, if the device supports it.
1 parent 3114f02 commit 2e24304

File tree

8 files changed

+63
-19
lines changed

8 files changed

+63
-19
lines changed

sycl/doc/EnvironmentVariables.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ Note that all device selectors will throw an exception if the filtered list of d
6969

7070
`(*) Note: Any means this environment variable is effective when set to any non-null value.`
7171

72+
## Controlling DPC++ CUDA Plugin
73+
74+
| Environment variable | Values | Description |
75+
| -------------------- | ------ | ----------- |
76+
| `SYCL_PI_CUDA_MAX_LOCAL_MEM_SIZE` | Integer | Specifies the maximum size of a local memory allocation in bytes. If the value exceeds the device's capabilities then a `sycl::runtime_error` is thrown. In order for the full error message to be printed, `SYCL_RT_WARNING_LEVEL=2` must be set. The default value for `SYCL_PI_CUDA_MAX_LOCAL_MEM_SIZE` is determined by the hardware. |
77+
7278
## Tools variables
7379

7480
| Environment variable | Values | Description |

sycl/include/CL/sycl/detail/common.hpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -131,19 +131,21 @@ static inline std::string codeToString(cl_int code) {
131131
#ifndef SYCL_SUPPRESS_EXCEPTIONS
132132
#include <CL/sycl/exception.hpp>
133133
// SYCL 1.2.1 exceptions
134-
#define __SYCL_REPORT_OCL_ERR_TO_EXC(expr, exc) \
134+
#define __SYCL_REPORT_OCL_ERR_TO_EXC(expr, exc, str) \
135135
{ \
136136
auto code = expr; \
137137
if (code != CL_SUCCESS) { \
138+
std::string err_str = \
139+
str ? "\n" + std::string(str) + "\n" : std::string{}; \
138140
throw exc(__SYCL_OCL_ERROR_REPORT + \
139-
cl::sycl::detail::codeToString(code), \
141+
cl::sycl::detail::codeToString(code) + err_str, \
140142
code); \
141143
} \
142144
}
143-
#define __SYCL_REPORT_OCL_ERR_TO_EXC_THROW(code, exc) \
144-
__SYCL_REPORT_OCL_ERR_TO_EXC(code, exc)
145+
#define __SYCL_REPORT_OCL_ERR_TO_EXC_THROW(code, exc, str) \
146+
__SYCL_REPORT_OCL_ERR_TO_EXC(code, exc, str)
145147
#define __SYCL_REPORT_OCL_ERR_TO_EXC_BASE(code) \
146-
__SYCL_REPORT_OCL_ERR_TO_EXC(code, cl::sycl::runtime_error)
148+
__SYCL_REPORT_OCL_ERR_TO_EXC(code, cl::sycl::runtime_error, nullptr)
147149
#else
148150
#define __SYCL_REPORT_OCL_ERR_TO_EXC_BASE(code) \
149151
__SYCL_REPORT_OCL_ERR_TO_STREAM(code)
@@ -164,15 +166,19 @@ static inline std::string codeToString(cl_int code) {
164166
#ifdef __SYCL_SUPPRESS_OCL_ERROR_REPORT
165167
// SYCL 1.2.1 exceptions
166168
#define __SYCL_CHECK_OCL_CODE(X) (void)(X)
167-
#define __SYCL_CHECK_OCL_CODE_THROW(X, EXC) (void)(X)
169+
#define __SYCL_CHECK_OCL_CODE_THROW(X, EXC, STR) \
170+
{ \
171+
(void)(X); \
172+
(void)(STR); \
173+
}
168174
#define __SYCL_CHECK_OCL_CODE_NO_EXC(X) (void)(X)
169175
// SYCL 2020 exceptions
170176
#define __SYCL_CHECK_CODE_THROW_VIA_ERRC(X, ERRC) (void)(X)
171177
#else
172178
// SYCL 1.2.1 exceptions
173179
#define __SYCL_CHECK_OCL_CODE(X) __SYCL_REPORT_OCL_ERR_TO_EXC_BASE(X)
174-
#define __SYCL_CHECK_OCL_CODE_THROW(X, EXC) \
175-
__SYCL_REPORT_OCL_ERR_TO_EXC_THROW(X, EXC)
180+
#define __SYCL_CHECK_OCL_CODE_THROW(X, EXC, STR) \
181+
__SYCL_REPORT_OCL_ERR_TO_EXC_THROW(X, EXC, STR)
176182
#define __SYCL_CHECK_OCL_CODE_NO_EXC(X) __SYCL_REPORT_OCL_ERR_TO_STREAM(X)
177183
// SYCL 2020 exceptions
178184
#define __SYCL_CHECK_CODE_THROW_VIA_ERRC(X, ERRC) \

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2867,6 +2867,28 @@ pi_result cuda_piEnqueueKernelLaunch(
28672867
retImplEv->start();
28682868
}
28692869

2870+
// Set local mem max size if env var is present
2871+
static const char *local_mem_sz_ptr =
2872+
std::getenv("SYCL_PI_CUDA_MAX_LOCAL_MEM_SIZE");
2873+
2874+
if (local_mem_sz_ptr) {
2875+
int device_max_local_mem = 0;
2876+
cuDeviceGetAttribute(
2877+
&device_max_local_mem,
2878+
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
2879+
command_queue->get_device()->get());
2880+
2881+
static const int env_val = std::atoi(local_mem_sz_ptr);
2882+
if (env_val <= 0 || env_val > device_max_local_mem) {
2883+
setErrorMessage("Invalid value specified for "
2884+
"SYCL_PI_CUDA_MAX_LOCAL_MEM_SIZE",
2885+
PI_PLUGIN_SPECIFIC_ERROR);
2886+
return PI_PLUGIN_SPECIFIC_ERROR;
2887+
}
2888+
PI_CHECK_ERROR(cuFuncSetAttribute(
2889+
cuFunc, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, env_val));
2890+
}
2891+
28702892
retError = PI_CHECK_ERROR(cuLaunchKernel(
28712893
cuFunc, blocksPerGrid[0], blocksPerGrid[1], blocksPerGrid[2],
28722894
threadsPerBlock[0], threadsPerBlock[1], threadsPerBlock[2], local_size,

sycl/plugins/cuda/pi_cuda.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,8 @@ struct _pi_queue {
442442

443443
_pi_context *get_context() const { return context_; };
444444

445+
_pi_device *get_device() const { return device_; };
446+
445447
pi_uint32 increment_reference_count() noexcept { return ++refCount_; }
446448

447449
pi_uint32 decrement_reference_count() noexcept { return --refCount_; }

sycl/source/detail/error_handling/enqueue_kernel.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace detail {
2222

2323
namespace enqueue_kernel_launch {
2424

25-
bool handleInvalidWorkGroupSize(const device_impl &DeviceImpl, pi_kernel Kernel,
25+
void handleInvalidWorkGroupSize(const device_impl &DeviceImpl, pi_kernel Kernel,
2626
const NDRDescT &NDRDesc) {
2727
const bool HasLocalSize = (NDRDesc.LocalSize[0] != 0);
2828

@@ -246,7 +246,7 @@ bool handleInvalidWorkGroupSize(const device_impl &DeviceImpl, pi_kernel Kernel,
246246
"PI backend failed. PI backend returns: " + codeToString(Error), Error);
247247
}
248248

249-
bool handleInvalidWorkItemSize(const device_impl &DeviceImpl,
249+
void handleInvalidWorkItemSize(const device_impl &DeviceImpl,
250250
const NDRDescT &NDRDesc) {
251251

252252
const plugin &Plugin = DeviceImpl.getPlugin();
@@ -265,10 +265,9 @@ bool handleInvalidWorkItemSize(const device_impl &DeviceImpl,
265265
" > " + std::to_string(MaxWISize[I]),
266266
PI_INVALID_WORK_ITEM_SIZE);
267267
}
268-
return 0;
269268
}
270269

271-
bool handleInvalidValue(const device_impl &DeviceImpl,
270+
void handleInvalidValue(const device_impl &DeviceImpl,
272271
const NDRDescT &NDRDesc) {
273272
const plugin &Plugin = DeviceImpl.getPlugin();
274273
RT::PiDevice Device = DeviceImpl.getHandleRef();
@@ -293,8 +292,8 @@ bool handleInvalidValue(const device_impl &DeviceImpl,
293292
"Native API failed. Native API returns: " + codeToString(Error), Error);
294293
}
295294

296-
bool handleError(pi_result Error, const device_impl &DeviceImpl,
297-
pi_kernel Kernel, const NDRDescT &NDRDesc) {
295+
void handleErrorOrWarning(pi_result Error, const device_impl &DeviceImpl,
296+
pi_kernel Kernel, const NDRDescT &NDRDesc) {
298297
assert(Error != PI_SUCCESS &&
299298
"Success is expected to be handled on caller side");
300299
switch (Error) {
@@ -343,6 +342,14 @@ bool handleError(pi_result Error, const device_impl &DeviceImpl,
343342
case PI_INVALID_VALUE:
344343
return handleInvalidValue(DeviceImpl, NDRDesc);
345344

345+
case PI_PLUGIN_SPECIFIC_ERROR:
346+
// checkPiResult does all the necessary handling for
347+
// PI_PLUGIN_SPECIFIC_ERROR, making sure an error is thrown or not,
348+
// depending on whether PI_PLUGIN_SPECIFIC_ERROR contains an error or a
349+
// warning. It also ensures that the contents of the error message buffer
350+
// (used only by PI_PLUGIN_SPECIFIC_ERROR) get handled correctly.
351+
return DeviceImpl.getPlugin().checkPiResult(Error);
352+
346353
// TODO: Handle other error codes
347354

348355
default:

sycl/source/detail/error_handling/error_handling.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ namespace enqueue_kernel_launch {
2525
///
2626
/// This function actually never returns and always throws an exception with
2727
/// error description.
28-
bool handleError(pi_result, const device_impl &, pi_kernel, const NDRDescT &);
28+
void handleErrorOrWarning(pi_result, const device_impl &, pi_kernel,
29+
const NDRDescT &);
2930
} // namespace enqueue_kernel_launch
3031

3132
} // namespace detail

sycl/source/detail/plugin.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ class plugin {
114114
/// \throw Exception if pi_result is not a PI_SUCCESS.
115115
template <typename Exception = cl::sycl::runtime_error>
116116
void checkPiResult(RT::PiResult pi_result) const {
117+
char *message = nullptr;
117118
if (pi_result == PI_PLUGIN_SPECIFIC_ERROR) {
118-
char *message = nullptr;
119119
pi_result = call_nocheck<PiApiKind::piPluginGetLastError>(&message);
120120

121121
// If the warning level is greater then 2 emit the message
@@ -126,7 +126,7 @@ class plugin {
126126
if (pi_result == PI_SUCCESS)
127127
return;
128128
}
129-
__SYCL_CHECK_OCL_CODE_THROW(pi_result, Exception);
129+
__SYCL_CHECK_OCL_CODE_THROW(pi_result, Exception, message);
130130
}
131131

132132
/// \throw SYCL 2020 exception(errc) if pi_result is not PI_SUCCESS

sycl/source/detail/scheduler/commands.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2185,8 +2185,8 @@ cl_int enqueueImpKernel(
21852185
// If we have got non-success error code, let's analyze it to emit nice
21862186
// exception explaining what was wrong
21872187
const device_impl &DeviceImpl = *(Queue->getDeviceImplPtr());
2188-
return detail::enqueue_kernel_launch::handleError(Error, DeviceImpl, Kernel,
2189-
NDRDesc);
2188+
detail::enqueue_kernel_launch::handleErrorOrWarning(Error, DeviceImpl,
2189+
Kernel, NDRDesc);
21902190
}
21912191

21922192
return PI_SUCCESS;

0 commit comments

Comments
 (0)