Skip to content

Commit a3542ed

Browse files
authored
[SYCL][CUDA] Decouple CUDA contexts from PI contexts (#8197)
This patch moves the CUDA context from the PI context to the PI device, and switches to always using the primary context. CUDA contexts are different from SYCL contexts in that they're tied to a single device, and that they are required to be active on a thread for most calls to the CUDA driver API. As shown in #8124 and #7526, the current mapping of CUDA context to PI context causes issues for device-based entry points that still need to call the CUDA APIs. We have workarounds for that, but they're a bit hacky, inefficient, and have a lot of edge-case issues. The peer-to-peer interface proposal in #6104 is also device-based, but enabling peer-to-peer for CUDA is done on the CUDA contexts, so the current mapping would make it difficult to implement. This patch solves most of these issues by decoupling the CUDA context from the SYCL context and simply managing the CUDA contexts in the devices. It also changes the CUDA context management to always use the primary context. This approach has a number of advantages:
* Use of the primary context is recommended by Nvidia
* Simplifies the CUDA context management in the plugin
* Available CUDA context in device-based entry points
* Likely more efficient in the general case, with fewer opportunities to accidentally cause costly CUDA context switches
* Easier and likely more efficient interactions with CUDA runtime applications
* Easier to expose P2P capabilities
* Easier to support multiple devices in a SYCL context

It does have a few drawbacks compared to the previous approach:
* Drops support for `make_context` interop, as there is no sensible "native handle" to pass in (`get_native` is still supported fine).
* No opportunity for users to separate their work into different CUDA contexts. It's unclear if there's any actual use case for this; it seems very uncommon in CUDA codebases to have multiple CUDA contexts for a single CUDA device in the same process.
So overall, I believe this should be a net benefit in general, and we could revisit it if we run into an edge case that needs more fine-grained CUDA context management.
1 parent fa8ce20 commit a3542ed

File tree

7 files changed

+28
-210
lines changed

7 files changed

+28
-210
lines changed

sycl/include/sycl/detail/properties_traits.def

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ __SYCL_PARAM_TRAITS_SPEC(sycl::property::no_init)
1111
__SYCL_PARAM_TRAITS_SPEC(
1212
sycl::property::context::cuda::use_primary_context) // Deprecated
1313
__SYCL_PARAM_TRAITS_SPEC(
14-
sycl::ext::oneapi::cuda::property::context::use_primary_context)
14+
sycl::ext::oneapi::cuda::property::context::use_primary_context) // Deprecated
1515
__SYCL_PARAM_TRAITS_SPEC(sycl::property::queue::in_order)
1616
__SYCL_PARAM_TRAITS_SPEC(sycl::property::reduction::initialize_to_identity)
1717
__SYCL_PARAM_TRAITS_SPEC(sycl::ext::oneapi::property::queue::priority_low)
1818
__SYCL_PARAM_TRAITS_SPEC(sycl::ext::oneapi::property::queue::priority_high)
19-
__SYCL_PARAM_TRAITS_SPEC(sycl::ext::oneapi::property::queue::priority_normal)
19+
__SYCL_PARAM_TRAITS_SPEC(sycl::ext::oneapi::property::queue::priority_normal)

sycl/include/sycl/ext/oneapi/experimental/backend/backend_traits_cuda.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ template <> struct BackendReturn<backend::ext_oneapi_cuda, platform> {
114114
template <> struct InteropFeatureSupportMap<backend::ext_oneapi_cuda> {
115115
static constexpr bool MakePlatform = false;
116116
static constexpr bool MakeDevice = true;
117-
static constexpr bool MakeContext = true;
117+
static constexpr bool MakeContext = false;
118118
static constexpr bool MakeQueue = true;
119119
static constexpr bool MakeEvent = true;
120120
static constexpr bool MakeBuffer = false;

sycl/include/sycl/properties/context_properties.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
namespace sycl {
1616
__SYCL_INLINE_VER_NAMESPACE(_V1) {
1717
namespace ext::oneapi::cuda::property::context {
18-
class use_primary_context : public ::sycl::detail::DataLessProperty<
19-
::sycl::detail::UsePrimaryContext> {};
18+
class __SYCL_DEPRECATED("the primary contexts are now always used")
19+
use_primary_context : public ::sycl::detail::DataLessProperty<
20+
::sycl::detail::UsePrimaryContext> {};
2021
} // namespace ext::oneapi::cuda::property::context
2122

2223
namespace property::context {

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 12 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -917,8 +917,11 @@ pi_result cuda_piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms,
917917
for (int i = 0; i < numDevices; ++i) {
918918
CUdevice device;
919919
err = PI_CHECK_ERROR(cuDeviceGet(&device, i));
920+
CUcontext context;
921+
err = PI_CHECK_ERROR(cuDevicePrimaryCtxRetain(&context, device));
922+
920923
platformIds[i].devices_.emplace_back(
921-
new _pi_device{device, &platformIds[i]});
924+
new _pi_device{device, context, &platformIds[i]});
922925

923926
{
924927
const auto &dev = platformIds[i].devices_.back().get();
@@ -1183,6 +1186,8 @@ pi_result cuda_piDeviceGetInfo(pi_device device, pi_device_info param_name,
11831186

11841187
assert(device != nullptr);
11851188

1189+
ScopedContext active(device->get_context());
1190+
11861191
switch (param_name) {
11871192
case PI_DEVICE_INFO_TYPE: {
11881193
return getInfo(param_value_size, param_value, param_value_size_ret,
@@ -1961,7 +1966,6 @@ pi_result cuda_piDeviceGetInfo(pi_device device, pi_device_info param_name,
19611966
}
19621967

19631968
case PI_EXT_INTEL_DEVICE_INFO_FREE_MEMORY: {
1964-
ScopedContext active(device);
19651969
size_t FreeMemory = 0;
19661970
size_t TotalMemory = 0;
19671971
sycl::detail::pi::assertion(cuMemGetInfo(&FreeMemory, &TotalMemory) ==
@@ -2121,50 +2125,10 @@ pi_result cuda_piContextCreate(const pi_context_properties *properties,
21212125
assert(retcontext != nullptr);
21222126
pi_result errcode_ret = PI_SUCCESS;
21232127

2124-
// Parse properties.
2125-
bool property_cuda_primary = false;
2126-
while (properties && (0 != *properties)) {
2127-
// Consume property ID.
2128-
pi_context_properties id = *properties;
2129-
++properties;
2130-
// Consume property value.
2131-
pi_context_properties value = *properties;
2132-
++properties;
2133-
switch (id) {
2134-
case __SYCL_PI_CONTEXT_PROPERTIES_CUDA_PRIMARY:
2135-
assert(value == PI_FALSE || value == PI_TRUE);
2136-
property_cuda_primary = static_cast<bool>(value);
2137-
break;
2138-
default:
2139-
// Unknown property.
2140-
sycl::detail::pi::die(
2141-
"Unknown piContextCreate property in property list");
2142-
return PI_ERROR_INVALID_VALUE;
2143-
}
2144-
}
2145-
21462128
std::unique_ptr<_pi_context> piContextPtr{nullptr};
21472129
try {
2148-
CUcontext current = nullptr;
2149-
2150-
if (property_cuda_primary) {
2151-
// Use the CUDA primary context and assume that we want to use it
2152-
// immediately as we want to forge context switches.
2153-
CUcontext Ctxt;
2154-
errcode_ret =
2155-
PI_CHECK_ERROR(cuDevicePrimaryCtxRetain(&Ctxt, devices[0]->get()));
2156-
piContextPtr = std::unique_ptr<_pi_context>(
2157-
new _pi_context{_pi_context::kind::primary, Ctxt, *devices});
2158-
errcode_ret = PI_CHECK_ERROR(cuCtxPushCurrent(Ctxt));
2159-
} else {
2160-
// Create a scoped context.
2161-
CUcontext newContext;
2162-
PI_CHECK_ERROR(cuCtxGetCurrent(&current));
2163-
errcode_ret = PI_CHECK_ERROR(
2164-
cuCtxCreate(&newContext, CU_CTX_MAP_HOST, devices[0]->get()));
2165-
piContextPtr = std::unique_ptr<_pi_context>(new _pi_context{
2166-
_pi_context::kind::user_defined, newContext, *devices});
2167-
}
2130+
piContextPtr = std::unique_ptr<_pi_context>(new _pi_context{*devices});
2131+
21682132
static std::once_flag initFlag;
21692133
std::call_once(
21702134
initFlag,
@@ -2176,14 +2140,6 @@ pi_result cuda_piContextCreate(const pi_context_properties *properties,
21762140
},
21772141
errcode_ret);
21782142

2179-
// For non-primary scoped contexts keep the last active on top of the stack
2180-
// as `cuCtxCreate` replaces it implicitly otherwise.
2181-
// Primary contexts are kept on top of the stack, so the previous context
2182-
// is not queried and therefore not recovered.
2183-
if (current != nullptr) {
2184-
PI_CHECK_ERROR(cuCtxSetCurrent(current));
2185-
}
2186-
21872143
*retcontext = piContextPtr.release();
21882144
} catch (pi_result err) {
21892145
errcode_ret = err;
@@ -2194,7 +2150,6 @@ pi_result cuda_piContextCreate(const pi_context_properties *properties,
21942150
}
21952151

21962152
pi_result cuda_piContextRelease(pi_context ctxt) {
2197-
21982153
assert(ctxt != nullptr);
21992154

22002155
if (ctxt->decrement_reference_count() > 0) {
@@ -2204,29 +2159,7 @@ pi_result cuda_piContextRelease(pi_context ctxt) {
22042159

22052160
std::unique_ptr<_pi_context> context{ctxt};
22062161

2207-
if (!ctxt->backend_has_ownership())
2208-
return PI_SUCCESS;
2209-
2210-
if (!ctxt->is_primary()) {
2211-
CUcontext cuCtxt = ctxt->get();
2212-
CUcontext current = nullptr;
2213-
cuCtxGetCurrent(&current);
2214-
if (cuCtxt != current) {
2215-
PI_CHECK_ERROR(cuCtxPushCurrent(cuCtxt));
2216-
}
2217-
PI_CHECK_ERROR(cuCtxSynchronize());
2218-
cuCtxGetCurrent(&current);
2219-
if (cuCtxt == current) {
2220-
PI_CHECK_ERROR(cuCtxPopCurrent(&current));
2221-
}
2222-
return PI_CHECK_ERROR(cuCtxDestroy(cuCtxt));
2223-
}
2224-
2225-
// Primary context is not destroyed, but released
2226-
CUdevice cuDev = ctxt->get_device()->get();
2227-
CUcontext current;
2228-
cuCtxPopCurrent(&current);
2229-
return PI_CHECK_ERROR(cuDevicePrimaryCtxRelease(cuDev));
2162+
return PI_SUCCESS;
22302163
}
22312164

22322165
/// Gets the native CUDA handle of a PI context object
@@ -2253,29 +2186,15 @@ pi_result cuda_piextContextCreateWithNativeHandle(pi_native_handle nativeHandle,
22532186
const pi_device *devices,
22542187
bool ownNativeHandle,
22552188
pi_context *piContext) {
2189+
(void)nativeHandle;
22562190
(void)num_devices;
22572191
(void)devices;
22582192
(void)ownNativeHandle;
2193+
(void)piContext;
22592194
assert(piContext != nullptr);
22602195
assert(ownNativeHandle == false);
22612196

2262-
CUcontext newContext = reinterpret_cast<CUcontext>(nativeHandle);
2263-
2264-
ScopedContext active(newContext);
2265-
2266-
// Get context's native device
2267-
CUdevice cu_device;
2268-
pi_result retErr = PI_CHECK_ERROR(cuCtxGetDevice(&cu_device));
2269-
2270-
// Create a SYCL device from the ctx device
2271-
pi_device device = nullptr;
2272-
retErr = cuda_piextDeviceCreateWithNativeHandle(cu_device, nullptr, &device);
2273-
2274-
// Create sycl context
2275-
*piContext = new _pi_context{_pi_context::kind::user_defined, newContext,
2276-
device, /*backend_owns*/ false};
2277-
2278-
return retErr;
2197+
return PI_ERROR_INVALID_OPERATION;
22792198
}
22802199

22812200
/// Creates a PI Memory object using a CUDA memory allocation.
@@ -2469,8 +2388,6 @@ pi_result cuda_piMemBufferPartition(pi_mem parent_buffer, pi_mem_flags flags,
24692388

24702389
std::unique_ptr<_pi_mem> retMemObj{nullptr};
24712390
try {
2472-
ScopedContext active(context);
2473-
24742391
retMemObj = std::unique_ptr<_pi_mem>{new _pi_mem{
24752392
context, parent_buffer, allocMode, ptr, hostPtr, bufferRegion.size}};
24762393
} catch (pi_result err) {

sycl/plugins/cuda/pi_cuda.hpp

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -86,28 +86,29 @@ struct _pi_device {
8686
using native_type = CUdevice;
8787

8888
native_type cuDevice_;
89+
CUcontext cuContext_;
8990
std::atomic_uint32_t refCount_;
9091
pi_platform platform_;
91-
pi_context context_;
9292

9393
static constexpr pi_uint32 max_work_item_dimensions = 3u;
9494
size_t max_work_item_sizes[max_work_item_dimensions];
9595
int max_work_group_size;
9696

9797
public:
98-
_pi_device(native_type cuDevice, pi_platform platform)
99-
: cuDevice_(cuDevice), refCount_{1}, platform_(platform) {}
98+
_pi_device(native_type cuDevice, CUcontext cuContext, pi_platform platform)
99+
: cuDevice_(cuDevice), cuContext_(cuContext), refCount_{1},
100+
platform_(platform) {}
101+
102+
~_pi_device() { cuDevicePrimaryCtxRelease(cuDevice_); }
100103

101104
native_type get() const noexcept { return cuDevice_; };
102105

106+
CUcontext get_context() const noexcept { return cuContext_; };
107+
103108
pi_uint32 get_reference_count() const noexcept { return refCount_; }
104109

105110
pi_platform get_platform() const noexcept { return platform_; };
106111

107-
void set_context(pi_context ctx) { context_ = ctx; };
108-
109-
pi_context get_context() { return context_; };
110-
111112
void save_max_work_item_sizes(size_t size,
112113
size_t *save_max_work_item_sizes) noexcept {
113114
memcpy(max_work_item_sizes, save_max_work_item_sizes, size);
@@ -174,16 +175,12 @@ struct _pi_context {
174175

175176
using native_type = CUcontext;
176177

177-
enum class kind { primary, user_defined } kind_;
178178
native_type cuContext_;
179179
_pi_device *deviceId_;
180180
std::atomic_uint32_t refCount_;
181181

182-
_pi_context(kind k, CUcontext ctxt, _pi_device *devId,
183-
bool backend_owns = true)
184-
: kind_{k}, cuContext_{ctxt}, deviceId_{devId}, refCount_{1},
185-
has_ownership{backend_owns} {
186-
deviceId_->set_context(this);
182+
_pi_context(_pi_device *devId)
183+
: cuContext_{devId->get_context()}, deviceId_{devId}, refCount_{1} {
187184
cuda_piDeviceRetain(deviceId_);
188185
};
189186

@@ -206,20 +203,15 @@ struct _pi_context {
206203

207204
native_type get() const noexcept { return cuContext_; }
208205

209-
bool is_primary() const noexcept { return kind_ == kind::primary; }
210-
211206
pi_uint32 increment_reference_count() noexcept { return ++refCount_; }
212207

213208
pi_uint32 decrement_reference_count() noexcept { return --refCount_; }
214209

215210
pi_uint32 get_reference_count() const noexcept { return refCount_; }
216211

217-
bool backend_has_ownership() const noexcept { return has_ownership; }
218-
219212
private:
220213
std::mutex mutex_;
221214
std::vector<deleter_data> extended_deleters_;
222-
const bool has_ownership;
223215
};
224216

225217
/// PI Mem mapping to CUDA memory allocations, both data and texture/surface.

sycl/test/basic_tests/interop-cuda.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,6 @@ int main() {
8787

8888
backend_input_t<backend::ext_oneapi_cuda, context> InteropContextInput{
8989
cu_context[0]};
90-
context InteropContext =
91-
make_context<backend::ext_oneapi_cuda>(InteropContextInput);
9290
event InteropEvent = make_event<backend::ext_oneapi_cuda>(cu_event, Context);
9391

9492
queue InteropQueue = make_queue<backend::ext_oneapi_cuda>(cu_queue, Context);

sycl/unittests/pi/cuda/test_base_objects.cpp

Lines changed: 0 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -79,96 +79,6 @@ TEST_F(CudaBaseObjectsTest, piContextCreate) {
7979
cuCtxGetApiVersion(cudaContext, &version);
8080
EXPECT_EQ(version, LATEST_KNOWN_CUDA_DRIVER_API_VERSION);
8181

82-
CUresult cuErr = cuCtxDestroy(cudaContext);
83-
ASSERT_EQ(cuErr, CUDA_SUCCESS);
84-
}
85-
86-
TEST_F(CudaBaseObjectsTest, piContextCreatePrimaryTrue) {
87-
pi_uint32 numPlatforms = 0;
88-
pi_platform platform;
89-
pi_device device;
90-
91-
ASSERT_EQ((plugin->call_nocheck<detail::PiApiKind::piPlatformsGet>(
92-
0, nullptr, &numPlatforms)),
93-
PI_SUCCESS)
94-
<< "piPlatformsGet failed.\n";
95-
96-
ASSERT_EQ((plugin->call_nocheck<detail::PiApiKind::piPlatformsGet>(
97-
numPlatforms, &platform, nullptr)),
98-
PI_SUCCESS)
99-
<< "piPlatformsGet failed.\n";
100-
101-
ASSERT_EQ((plugin->call_nocheck<detail::PiApiKind::piDevicesGet>(
102-
platform, PI_DEVICE_TYPE_GPU, 1, &device, nullptr)),
103-
PI_SUCCESS);
104-
pi_context_properties properties[] = {
105-
__SYCL_PI_CONTEXT_PROPERTIES_CUDA_PRIMARY, PI_TRUE, 0};
106-
107-
pi_context ctxt;
108-
ASSERT_EQ((plugin->call_nocheck<detail::PiApiKind::piContextCreate>(
109-
properties, 1, &device, nullptr, nullptr, &ctxt)),
110-
PI_SUCCESS);
111-
EXPECT_NE(ctxt, nullptr);
112-
EXPECT_EQ(ctxt->get_device(), device);
113-
EXPECT_TRUE(ctxt->is_primary());
114-
115-
// Retrieve the cuCtxt to check information is correct
116-
CUcontext cudaContext = ctxt->get();
117-
unsigned int version = 0;
118-
CUresult cuErr = cuCtxGetApiVersion(cudaContext, &version);
119-
ASSERT_EQ(cuErr, CUDA_SUCCESS);
120-
EXPECT_EQ(version, LATEST_KNOWN_CUDA_DRIVER_API_VERSION);
121-
122-
// Current context in the stack?
123-
CUcontext current;
124-
cuErr = cuCtxGetCurrent(&current);
125-
ASSERT_EQ(cuErr, CUDA_SUCCESS);
126-
ASSERT_EQ(current, cudaContext);
127-
ASSERT_EQ((plugin->call_nocheck<detail::PiApiKind::piContextRelease>(ctxt)),
128-
PI_SUCCESS);
129-
}
130-
131-
TEST_F(CudaBaseObjectsTest, piContextCreatePrimaryFalse) {
132-
pi_uint32 numPlatforms = 0;
133-
pi_platform platform;
134-
pi_device device;
135-
136-
ASSERT_EQ((plugin->call_nocheck<detail::PiApiKind::piPlatformsGet>(
137-
0, nullptr, &numPlatforms)),
138-
PI_SUCCESS)
139-
<< "piPlatformsGet failed.\n";
140-
141-
ASSERT_EQ((plugin->call_nocheck<detail::PiApiKind::piPlatformsGet>(
142-
numPlatforms, &platform, nullptr)),
143-
PI_SUCCESS)
144-
<< "piPlatformsGet failed.\n";
145-
146-
ASSERT_EQ((plugin->call_nocheck<detail::PiApiKind::piDevicesGet>(
147-
platform, PI_DEVICE_TYPE_GPU, 1, &device, nullptr)),
148-
PI_SUCCESS);
149-
pi_context_properties properties[] = {
150-
__SYCL_PI_CONTEXT_PROPERTIES_CUDA_PRIMARY, PI_FALSE, 0};
151-
152-
pi_context ctxt;
153-
ASSERT_EQ((plugin->call_nocheck<detail::PiApiKind::piContextCreate>(
154-
properties, 1, &device, nullptr, nullptr, &ctxt)),
155-
PI_SUCCESS);
156-
EXPECT_NE(ctxt, nullptr);
157-
EXPECT_EQ(ctxt->get_device(), device);
158-
EXPECT_FALSE(ctxt->is_primary());
159-
160-
// Retrieve the cuCtxt to check information is correct
161-
CUcontext cudaContext = ctxt->get();
162-
unsigned int version = 0;
163-
CUresult cuErr = cuCtxGetApiVersion(cudaContext, &version);
164-
ASSERT_EQ(cuErr, CUDA_SUCCESS);
165-
EXPECT_EQ(version, LATEST_KNOWN_CUDA_DRIVER_API_VERSION);
166-
167-
// Current context in the stack?
168-
CUcontext current;
169-
cuErr = cuCtxGetCurrent(&current);
170-
ASSERT_EQ(cuErr, CUDA_SUCCESS);
171-
ASSERT_EQ(current, cudaContext);
17282
ASSERT_EQ((plugin->call_nocheck<detail::PiApiKind::piContextRelease>(ctxt)),
17383
PI_SUCCESS);
17484
}

0 commit comments

Comments
 (0)