Commit 82ea16d

Author: Diptorup Deb
Remove workarounds to make device equality work. (#338)
dpcpp 2021.2 fixed device equality in the dpcpp runtime, so several workarounds in dpctl can now be removed:
- DPCTLDeviceMgr_GetDeviceAndContextPair is renamed to DPCTLDeviceMgr_GetCachedContext and returns only a DPCTLSyclContextRef instead of a pair of DPCTLSyclContextRef and DPCTLSyclDeviceRef.
- Remove the DPCTLDeviceMgr_AreEq function and the related internal helper functions for hashing SYCL devices.
- Remove the DeviceWrapper class inside dpctl_sycl_device_manager.cpp.
- Remove the DPCTL_DeviceAndContextPair type from the C API.
- Make equivalent changes to the Python API.
1 parent f0ada84 commit 82ea16d

8 files changed: 105 additions, 224 deletions

dpctl-capi/include/dpctl_sycl_device_manager.h

Lines changed: 7 additions & 46 deletions
@@ -39,44 +39,10 @@ DPCTL_C_EXTERN_C_BEGIN
  * @defgroup DeviceManager Device management helper functions
  */
 
-/*!
- * @brief Contains a #DPCTLSyclDeviceRef and #DPCTLSyclContextRef 2-tuple that
- * contains a sycl::device and a sycl::context associated with that device.
- */
-typedef struct DPCTL_API DeviceAndContextPair
-{
-    DPCTLSyclDeviceRef DRef;
-    DPCTLSyclContextRef CRef;
-} DPCTL_DeviceAndContextPair;
-
 // Declares a set of types abd functions to deal with vectors of
 // DPCTLSyclDeviceRef. Refer dpctl_vector_macros.h
 DPCTL_DECLARE_VECTOR(Device)
 
-/*!
- * @brief Checks if two ::DPCTLSyclDeviceRef objects point to the same
- * sycl::device.
- *
- * DPC++ 2021.1.2 has some bugs that prevent the equality of sycl::device
- * objects to work correctly. The DPCTLDeviceMgr_AreEq implements a workaround
- * to check if two sycl::device pointers are equivalent. Since, DPC++ uses
- * std::shared_pointer wrappers for sycl::device objects we check if the raw
- * pointer (shared_pointer.get()) for each device are the same. One caveat is
- * that the trick works only for non-host devices. The function evaluates host
- * devices separately and always assumes that all host devices are equivalent,
- * while checking for the raw pointer equivalent for all other types of devices.
- * The workaround will be removed once DPC++ is fixed to correctly check device
- * equivalence.
- *
- * @param DRef1 First opaque pointer to a sycl device.
- * @param DRef2 Second opaque pointer to a sycl device.
- * @return True if the underlying sycl::device are same, false otherwise.
- * @ingroup DeviceManager
- */
-DPCTL_API
-bool DPCTLDeviceMgr_AreEq(__dpctl_keep const DPCTLSyclDeviceRef DRef1,
-                          __dpctl_keep const DPCTLSyclDeviceRef DRef2);
-
 /*!
  * @brief Returns a pointer to a std::vector<sycl::DPCTLSyclDeviceRef>
  * containing the set of ::DPCTLSyclDeviceRef pointers matching the passed in
@@ -110,25 +76,20 @@ __dpctl_give DPCTLDeviceVectorRef
 DPCTLDeviceMgr_GetDevices(int device_identifier);
 
 /*!
- * @brief Returns the default sycl context inside an opaque DPCTLSyclContextRef
- * pointer for the DPCTLSyclDeviceRef input argument.
+ * @brief If the DPCTLSyclDeviceRef argument is a root device, then this
+ * function returns a cached default SYCL context for that device.
  *
  * @param DRef A pointer to a sycl::device that will be used to
  * search an internal map containing a cached "default"
  * sycl::context for the device.
- * @return A #DPCTL_DeviceAndContextPair struct containing the cached
- * #DPCTLSyclContextRef associated with the #DPCTLSyclDeviceRef argument passed
- * to the function. The DPCTL_DeviceAndContextPair also contains a
- * #DPCTLSyclDeviceRef pointer pointing to the same device as the input
- * #DPCTLSyclDeviceRef. The returned #DPCTLSyclDeviceRef was cached along with
- * the #DPCTLSyclContextRef. This is a workaround till device equality is
- * properly fixed in DPC++. If the #DPCTLSyclDeviceRef is not found in the cache
- * then DPCTL_DeviceAndContextPair contains a pair of nullptr.
+ * @return A DPCTLSyclContextRef associated with the #DPCTLSyclDeviceRef
+ * argument passed to the function. If the #DPCTLSyclDeviceRef is not found in
+ * the cache, then returns a nullptr.
  * @ingroup DeviceManager
  */
 DPCTL_API
-DPCTL_DeviceAndContextPair DPCTLDeviceMgr_GetDeviceAndContextPair(
-    __dpctl_keep const DPCTLSyclDeviceRef DRef);
+DPCTLSyclContextRef
+DPCTLDeviceMgr_GetCachedContext(__dpctl_keep const DPCTLSyclDeviceRef DRef);
 
 /*!
  * @brief Get the number of available devices for given backend and device type

dpctl-capi/source/dpctl_sycl_device_interface.cpp

Lines changed: 6 additions & 4 deletions
@@ -366,10 +366,12 @@ bool DPCTLDevice_IsHostUnifiedMemory(__dpctl_keep const DPCTLSyclDeviceRef DRef)
 bool DPCTLDevice_AreEq(__dpctl_keep const DPCTLSyclDeviceRef DRef1,
                        __dpctl_keep const DPCTLSyclDeviceRef DRef2)
 {
-    // Note: DPCPP does not yet support device equality of the form:
-    // *unwrap(DevRef1) == *unwrap(DevRef2). Till DPCPP is fixed we use the
-    // custom equality checker implemented inside DPCTLDeviceMgr.
-    return DPCTLDeviceMgr_AreEq(DRef1, DRef2);
+    auto D1 = unwrap(DRef1);
+    auto D2 = unwrap(DRef2);
+    if (D1 && D2)
+        return *D1 == *D2;
+    else
+        return false;
 }
 
 bool DPCTLDevice_HasAspect(__dpctl_keep const DPCTLSyclDeviceRef DRef,
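
The new DPCTLDevice_AreEq body is a null check on both handles followed by the device type's own operator==. Below is a minimal sketch of that pattern outside of SYCL. FakeDevice, DeviceRef, wrap, unwrap and devices_are_equal are hypothetical stand-ins invented for illustration, not dpctl or SYCL names.

```cpp
#include <cassert>

// Hypothetical stand-in for sycl::device; only operator== matters here.
struct FakeDevice
{
    int id;
    bool operator==(const FakeDevice &other) const { return id == other.id; }
};

// Hypothetical opaque handle playing the role of DPCTLSyclDeviceRef.
typedef struct OpaqueDevice *DeviceRef;

// Hypothetical conversion helpers between the handle and the real object.
inline FakeDevice *unwrap(DeviceRef ref)
{
    return reinterpret_cast<FakeDevice *>(ref);
}
inline DeviceRef wrap(FakeDevice *dev)
{
    return reinterpret_cast<DeviceRef>(dev);
}

// Two handles are equal only if both are non-null and the objects they
// point to compare equal by value.
bool devices_are_equal(DeviceRef ref1, DeviceRef ref2)
{
    const FakeDevice *d1 = unwrap(ref1);
    const FakeDevice *d2 = unwrap(ref2);
    if (d1 && d2)
        return *d1 == *d2;
    return false;
}

// Usage: distinct objects with equal state compare equal; null never does.
void usage_example()
{
    FakeDevice a{1}, b{1}, c{2};
    assert(devices_are_equal(wrap(&a), wrap(&b)));
    assert(!devices_are_equal(wrap(&a), wrap(&c)));
    assert(!devices_are_equal(nullptr, wrap(&a)));
}
```

Unlike the removed raw-pointer comparison, value equality lets two separately allocated copies of the same device compare equal.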

dpctl-capi/source/dpctl_sycl_device_manager.cpp

Lines changed: 28 additions & 105 deletions
@@ -41,22 +41,6 @@ namespace
 DEFINE_SIMPLE_CONVERSION_FUNCTIONS(device, DPCTLSyclDeviceRef)
 DEFINE_SIMPLE_CONVERSION_FUNCTIONS(context, DPCTLSyclContextRef)
 
-/* Checks if two devices are equal based on the underlying native pointer.
- */
-bool deviceEqChecker(const device &D1, const device &D2)
-{
-    if (D1.is_host() && D2.is_host()) {
-        return true;
-    }
-    else if ((D1.is_host() && !D2.is_host()) || (D2.is_host() && !D1.is_host()))
-    {
-        return false;
-    }
-    else {
-        return D1.get() == D2.get();
-    }
-}
-
 /*
  * Helper function to print the metadata for a sycl::device.
  */
@@ -80,64 +64,9 @@ void print_device_info(const device &Device)
     std::cout << ss.str();
 }
 
-/*
- * Helper class to store DPCTLSyclDeviceType and DPCTLSyclBackendType attributes
- * for a device along with the SYCL device.
- */
-struct DeviceWrapper
-{
-    device SyclDevice;
-    DPCTLSyclBackendType Bty;
-    DPCTLSyclDeviceType Dty;
-
-    DeviceWrapper(const device &Device)
-        : SyclDevice(Device), Bty(DPCTL_SyclBackendToDPCTLBackendType(
-                                  Device.get_platform().get_backend())),
-          Dty(DPCTL_SyclDeviceTypeToDPCTLDeviceType(
-              Device.get_info<info::device::device_type>()))
-    {
-    }
-
-    // The constructor is provided for convenience, so that we do not have to
-    // lookup the BackendType and DeviceType if not needed.
-    DeviceWrapper(const device &Device,
-                  DPCTLSyclBackendType Bty,
-                  DPCTLSyclDeviceType Dty)
-        : SyclDevice(Device), Bty(Bty), Dty(Dty)
-    {
-    }
-};
-
-auto getHash(const device &d)
-{
-    if (d.is_host()) {
-        return std::hash<unsigned long long>{}(-1);
-    }
-    else {
-        return std::hash<decltype(d.get())>{}(d.get());
-    }
-}
-
-struct DeviceHasher
-{
-    size_t operator()(const DeviceWrapper &d) const
-    {
-        return getHash(d.SyclDevice);
-    }
-};
-
-struct DeviceEqPred
-{
-    bool operator()(const DeviceWrapper &d1, const DeviceWrapper &d2) const
-    {
-        return deviceEqChecker(d1.SyclDevice, d2.SyclDevice);
-    }
-};
-
 struct DeviceCacheBuilder
 {
-    using DeviceCache =
-        std::unordered_map<DeviceWrapper, context, DeviceHasher, DeviceEqPred>;
+    using DeviceCache = std::unordered_map<device, context>;
     /* This function implements a workaround to the current lack of a default
      * context per root device in DPC++. The map stores a "default" context for
      * each root device, and the QMgrHelper uses the map whenever it creates a
@@ -181,40 +110,29 @@ struct DeviceCacheBuilder
 #include "dpctl_vector_templ.cpp"
 #undef EL
 
-bool DPCTLDeviceMgr_AreEq(__dpctl_keep const DPCTLSyclDeviceRef DRef1,
-                          __dpctl_keep const DPCTLSyclDeviceRef DRef2)
+DPCTLSyclContextRef
+DPCTLDeviceMgr_GetCachedContext(__dpctl_keep const DPCTLSyclDeviceRef DRef)
 {
-    auto D1 = unwrap(DRef1);
-    auto D2 = unwrap(DRef2);
-    if (D1 && D2)
-        return deviceEqChecker(*D1, *D2);
-    else
-        return false;
-}
+    DPCTLSyclContextRef CRef = nullptr;
 
-DPCTL_DeviceAndContextPair DPCTLDeviceMgr_GetDeviceAndContextPair(
-    __dpctl_keep const DPCTLSyclDeviceRef DRef)
-{
-    DPCTL_DeviceAndContextPair rPair{nullptr, nullptr};
     auto Device = unwrap(DRef);
-    if (!Device) {
-        return rPair;
-    }
-    DeviceWrapper DWrapper{*Device, DPCTLSyclBackendType::DPCTL_UNKNOWN_BACKEND,
-                           DPCTLSyclDeviceType::DPCTL_UNKNOWN_DEVICE};
+    if (!Device)
+        return CRef;
+
     auto &cache = DeviceCacheBuilder::getDeviceCache();
-    auto entry = cache.find(DWrapper);
+    auto entry = cache.find(*Device);
     if (entry != cache.end()) {
         try {
-            rPair.DRef = wrap(new device(entry->first.SyclDevice));
-            rPair.CRef = wrap(new context(entry->second));
+            CRef = wrap(new context(entry->second));
         } catch (std::bad_alloc const &ba) {
             std::cerr << ba.what() << std::endl;
-            rPair.DRef = nullptr;
-            rPair.CRef = nullptr;
+            CRef = nullptr;
         }
     }
-    return rPair;
+    else {
+        std::cerr << "No cached default context for device" << std::endl;
+    }
+    return CRef;
 }
 
 __dpctl_give DPCTLDeviceVectorRef
@@ -228,12 +146,14 @@ DPCTLDeviceMgr_GetDevices(int device_identifier)
         return nullptr;
     }
     auto &cache = DeviceCacheBuilder::getDeviceCache();
-    Devices->reserve(cache.size());
+
     for (const auto &entry : cache) {
-        if ((device_identifier & entry.first.Bty) &&
-            (device_identifier & entry.first.Dty))
-        {
-            Devices->emplace_back(wrap(new device(entry.first.SyclDevice)));
+        auto Bty(DPCTL_SyclBackendToDPCTLBackendType(
+            entry.first.get_platform().get_backend()));
+        auto Dty(DPCTL_SyclDeviceTypeToDPCTLDeviceType(
+            entry.first.get_info<info::device::device_type>()));
+        if ((device_identifier & Bty) && (device_identifier & Dty)) {
+            Devices->emplace_back(wrap(new device(entry.first)));
         }
     }
     // the wrap function is defined inside dpctl_vector_templ.cpp
@@ -248,11 +168,14 @@ size_t DPCTLDeviceMgr_GetNumDevices(int device_identifier)
 {
     size_t nDevices = 0;
     auto &cache = DeviceCacheBuilder::getDeviceCache();
-    for (const auto &entry : cache)
-        if ((device_identifier & entry.first.Bty) &&
-            (device_identifier & entry.first.Dty))
+    for (const auto &entry : cache) {
+        auto Bty(DPCTL_SyclBackendToDPCTLBackendType(
+            entry.first.get_platform().get_backend()));
+        auto Dty(DPCTL_SyclDeviceTypeToDPCTLDeviceType(
+            entry.first.get_info<info::device::device_type>()));
+        if ((device_identifier & Bty) && (device_identifier & Dty))
             ++nDevices;
-
+    }
     return nDevices;
 }
 
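
With DeviceWrapper, DeviceHasher and DeviceEqPred gone, the cache is a plain std::unordered_map keyed by the device itself. That only compiles when the key type has a std::hash specialization and an operator==, which this commit relies on the runtime to provide. The sketch below shows the same shape with stand-in types. FakeDevice, FakeContext, the bit constants and both helper functions are hypothetical names chosen for illustration, and the hash specialization is written by hand here only because the stand-in type has none.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <unordered_map>

// Hypothetical bit flags; the real backend/device-type enums live in dpctl.
constexpr int BACKEND_A = 1 << 16;
constexpr int BACKEND_B = 1 << 17;
constexpr int TYPE_CPU = 1 << 0;
constexpr int TYPE_GPU = 1 << 1;

// Hypothetical stand-in for sycl::device.
struct FakeDevice
{
    int id;
    int backend;
    int type;
    bool operator==(const FakeDevice &o) const { return id == o.id; }
};

// The default hasher of unordered_map needs std::hash for the key type.
namespace std
{
template <> struct hash<FakeDevice>
{
    size_t operator()(const FakeDevice &d) const noexcept
    {
        return hash<int>{}(d.id);
    }
};
} // namespace std

// Hypothetical stand-in for sycl::context.
struct FakeContext
{
    int device_id;
};

using DeviceCache = std::unordered_map<FakeDevice, FakeContext>;

// Look the device up directly; nullptr signals a cache miss.
const FakeContext *find_cached_context(const DeviceCache &cache,
                                       const FakeDevice &dev)
{
    auto entry = cache.find(dev);
    return entry != cache.end() ? &entry->second : nullptr;
}

// Backend and type are read from the key at query time instead of being
// stored next to it, mirroring the Bty/Dty locals in the diff.
std::size_t count_devices(const DeviceCache &cache, int device_identifier)
{
    std::size_t n = 0;
    for (const auto &entry : cache) {
        const int bty = entry.first.backend;
        const int dty = entry.first.type;
        if ((device_identifier & bty) && (device_identifier & dty))
            ++n;
    }
    return n;
}

void usage_example()
{
    DeviceCache cache;
    cache.emplace(FakeDevice{0, BACKEND_A, TYPE_CPU}, FakeContext{0});
    cache.emplace(FakeDevice{1, BACKEND_B, TYPE_GPU}, FakeContext{1});
    assert(find_cached_context(cache, FakeDevice{1, BACKEND_B, TYPE_GPU}));
    assert(!find_cached_context(cache, FakeDevice{9, BACKEND_A, TYPE_CPU}));
    assert(count_devices(cache, BACKEND_A | TYPE_CPU) == 1);
}
```

The trade-off visible in the diff is that backend and device type are now recomputed on every iteration rather than cached in the key.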

dpctl-capi/source/dpctl_sycl_queue_interface.cpp

Lines changed: 36 additions & 26 deletions
@@ -137,16 +137,13 @@ std::unique_ptr<property_list> create_property_list(int properties)
 }
 
 __dpctl_give DPCTLSyclQueueRef
-getQueueImpl(__dpctl_take DPCTLSyclContextRef cRef,
-             __dpctl_take DPCTLSyclDeviceRef dRef,
+getQueueImpl(__dpctl_keep DPCTLSyclContextRef cRef,
+             __dpctl_keep DPCTLSyclDeviceRef dRef,
              error_handler_callback *handler,
              int properties)
 {
     DPCTLSyclQueueRef qRef = nullptr;
     qRef = DPCTLQueue_Create(cRef, dRef, handler, properties);
-    DPCTLContext_Delete(cRef);
-    DPCTLDevice_Delete(dRef);
-
     return qRef;
 }
 
@@ -216,36 +213,37 @@ DPCTLQueue_Create(__dpctl_keep const DPCTLSyclContextRef CRef,
 }
 
 __dpctl_give DPCTLSyclQueueRef
-DPCTLQueue_CreateForDevice(__dpctl_keep const DPCTLSyclDeviceRef dRef,
+DPCTLQueue_CreateForDevice(__dpctl_keep const DPCTLSyclDeviceRef DRef,
                            error_handler_callback *handler,
                            int properties)
 {
-    DPCTLSyclQueueRef qRef = nullptr;
-    auto Device = unwrap(dRef);
+    DPCTLSyclContextRef CRef = nullptr;
+    DPCTLSyclQueueRef QRef = nullptr;
+    auto Device = unwrap(DRef);
 
     if (!Device) {
         std::cerr << "Cannot create queue from NULL device reference.\n";
-        return qRef;
+        return QRef;
     }
-    auto cached = DPCTLDeviceMgr_GetDeviceAndContextPair(dRef);
-    if (cached.CRef) {
-        qRef = getQueueImpl(cached.CRef, cached.DRef, handler, properties);
-    }
-    // We only cache contexts for root devices. If the dRef argument points to
-    // a sub-device, then the queue manager allocates a new context and creates
-    // a new queue to retrun to caller. Note that any context for a sub-device
-    // is not cached.
-    else {
+    // Check if a cached default context exists for the device.
+    CRef = DPCTLDeviceMgr_GetCachedContext(DRef);
+    // If a cached default context was found, that context will be used to use
+    // create the new queue. When a default cached context was not found, as
+    // will be the case for non-root devices, i.e., sub-devices, a new context
+    // will be allocated. Note that any newly allocated context is not cached.
+    if (!CRef) {
         try {
-            auto CRef = wrap(new context(*Device));
-            auto DRef_copy = wrap(new device(*Device));
-            qRef = getQueueImpl(CRef, DRef_copy, handler, properties);
+            CRef = wrap(new context(*Device));
         } catch (std::bad_alloc const &ba) {
             std::cerr << ba.what() << std::endl;
+            return QRef;
         }
     }
-
-    return qRef;
+    // At this point we have a valid context and the queue can be allocated.
+    QRef = getQueueImpl(CRef, DRef, handler, properties);
+    // Free the context
+    DPCTLContext_Delete(CRef);
+    return QRef;
 }
 
 /*!
@@ -304,9 +302,20 @@ DPCTLSyclBackendType DPCTLQueue_GetBackend(__dpctl_keep DPCTLSyclQueueRef QRef)
 __dpctl_give DPCTLSyclDeviceRef
 DPCTLQueue_GetDevice(__dpctl_keep const DPCTLSyclQueueRef QRef)
 {
+    DPCTLSyclDeviceRef DRef = nullptr;
     auto Q = unwrap(QRef);
-    auto Device = new device(Q->get_device());
-    return wrap(Device);
+    if (Q) {
+        try {
+            auto Device = new device(Q->get_device());
+            DRef = wrap(Device);
+        } catch (std::bad_alloc const &ba) {
+            std::cerr << ba.what() << '\n';
+        }
+    }
+    else {
+        std::cerr << "Could not get the device for this queue.\n";
+    }
+    return DRef;
 }
 
 __dpctl_give DPCTLSyclContextRef
@@ -438,7 +447,8 @@ DPCTLQueue_SubmitNDRange(__dpctl_keep const DPCTLSyclKernelRef KRef,
 
 void DPCTLQueue_Wait(__dpctl_keep DPCTLSyclQueueRef QRef)
 {
-    // \todo what happens if the QRef is null or a pointer to a valid sycl queue
+    // \todo what happens if the QRef is null or a pointer to a valid sycl
+    // queue
     auto SyclQueue = unwrap(QRef);
     SyclQueue->wait();
 }
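
The queue changes move ownership of the context handle out of getQueueImpl (now "keep") and into DPCTLQueue_CreateForDevice, which gets a context from the cache or allocates one, creates the queue, and then frees the handle itself. Below is a minimal sketch of that control flow, assuming stand-in types. FakeContext, FakeQueue, ContextCache and all three functions are hypothetical, and std::nothrow is swapped in for the try/catch on std::bad_alloc used in the diff.

```cpp
#include <cassert>
#include <new>
#include <unordered_map>

// Hypothetical stand-ins for sycl::context and sycl::queue.
struct FakeContext
{
    int device_id;
    bool from_cache;
};
struct FakeQueue
{
    FakeContext ctx; // the queue keeps its own copy of the context
    int device_id;
};

// Hypothetical cache of default contexts, keyed by a device id.
using ContextCache = std::unordered_map<int, FakeContext>;

// Returns a newly allocated copy of the cached context, or nullptr on a
// miss. The caller owns the result.
FakeContext *get_cached_context(const ContextCache &cache, int device_id)
{
    auto entry = cache.find(device_id);
    if (entry == cache.end())
        return nullptr;
    return new (std::nothrow) FakeContext(entry->second);
}

// "Keep" semantics: reads ctx but never frees it.
FakeQueue *make_queue(const FakeContext *ctx, int device_id)
{
    return new (std::nothrow) FakeQueue{*ctx, device_id};
}

FakeQueue *create_queue_for_device(const ContextCache &cache, int device_id)
{
    FakeContext *ctx = get_cached_context(cache, device_id);
    if (!ctx) {
        // Cache miss (e.g. a sub-device): allocate a fresh, uncached context.
        ctx = new (std::nothrow) FakeContext{device_id, false};
        if (!ctx)
            return nullptr;
    }
    FakeQueue *queue = make_queue(ctx, device_id);
    delete ctx; // single cleanup point, owned by this function
    return queue;
}

void usage_example()
{
    ContextCache cache{{0, FakeContext{0, true}}};
    FakeQueue *root = create_queue_for_device(cache, 0);
    FakeQueue *sub = create_queue_for_device(cache, 5);
    assert(root && root->ctx.from_cache);
    assert(sub && !sub->ctx.from_cache);
    delete root;
    delete sub;
}
```

Keeping allocation and deletion of the temporary handle in one function gives both the cached and uncached paths the same cleanup, which is what lets getQueueImpl drop its two delete calls.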
