Skip to content

Remove workarounds to make device equality work. #338

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 7 additions & 46 deletions dpctl-capi/include/dpctl_sycl_device_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,44 +39,10 @@ DPCTL_C_EXTERN_C_BEGIN
* @defgroup DeviceManager Device management helper functions
*/

/*!
* @brief Contains a #DPCTLSyclDeviceRef and #DPCTLSyclContextRef 2-tuple that
* contains a sycl::device and a sycl::context associated with that device.
*/
typedef struct DPCTL_API DeviceAndContextPair
{
DPCTLSyclDeviceRef DRef;
DPCTLSyclContextRef CRef;
} DPCTL_DeviceAndContextPair;

// Declares a set of types abd functions to deal with vectors of
// DPCTLSyclDeviceRef. Refer dpctl_vector_macros.h
DPCTL_DECLARE_VECTOR(Device)

/*!
* @brief Checks if two ::DPCTLSyclDeviceRef objects point to the same
* sycl::device.
*
* DPC++ 2021.1.2 has some bugs that prevent the equality of sycl::device
* objects to work correctly. The DPCTLDeviceMgr_AreEq implements a workaround
* to check if two sycl::device pointers are equivalent. Since, DPC++ uses
* std::shared_pointer wrappers for sycl::device objects we check if the raw
* pointer (shared_pointer.get()) for each device are the same. One caveat is
* that the trick works only for non-host devices. The function evaluates host
* devices separately and always assumes that all host devices are equivalent,
* while checking for the raw pointer equivalent for all other types of devices.
* The workaround will be removed once DPC++ is fixed to correctly check device
* equivalence.
*
* @param DRef1 First opaque pointer to a sycl device.
* @param DRef2 Second opaque pointer to a sycl device.
* @return True if the underlying sycl::device are same, false otherwise.
* @ingroup DeviceManager
*/
DPCTL_API
bool DPCTLDeviceMgr_AreEq(__dpctl_keep const DPCTLSyclDeviceRef DRef1,
__dpctl_keep const DPCTLSyclDeviceRef DRef2);

/*!
* @brief Returns a pointer to a std::vector<sycl::DPCTLSyclDeviceRef>
* containing the set of ::DPCTLSyclDeviceRef pointers matching the passed in
Expand Down Expand Up @@ -110,25 +76,20 @@ __dpctl_give DPCTLDeviceVectorRef
DPCTLDeviceMgr_GetDevices(int device_identifier);

/*!
* @brief Returns the default sycl context inside an opaque DPCTLSyclContextRef
* pointer for the DPCTLSyclDeviceRef input argument.
* @brief If the DPCTLSyclDeviceRef argument is a root device, then this
* function returns a cached default SYCL context for that device.
*
* @param DRef A pointer to a sycl::device that will be used to
* search an internal map containing a cached "default"
* sycl::context for the device.
* @return A #DPCTL_DeviceAndContextPair struct containing the cached
* #DPCTLSyclContextRef associated with the #DPCTLSyclDeviceRef argument passed
* to the function. The DPCTL_DeviceAndContextPair also contains a
* #DPCTLSyclDeviceRef pointer pointing to the same device as the input
* #DPCTLSyclDeviceRef. The returned #DPCTLSyclDeviceRef was cached along with
* the #DPCTLSyclContextRef. This is a workaround till device equality is
* properly fixed in DPC++. If the #DPCTLSyclDeviceRef is not found in the cache
* then DPCTL_DeviceAndContextPair contains a pair of nullptr.
* @return A DPCTLSyclContextRef associated with the #DPCTLSyclDeviceRef
* argument passed to the function. If the #DPCTLSyclDeviceRef is not found in
* the cache, then returns a nullptr.
* @ingroup DeviceManager
*/
DPCTL_API
DPCTL_DeviceAndContextPair DPCTLDeviceMgr_GetDeviceAndContextPair(
__dpctl_keep const DPCTLSyclDeviceRef DRef);
DPCTLSyclContextRef
DPCTLDeviceMgr_GetCachedContext(__dpctl_keep const DPCTLSyclDeviceRef DRef);

/*!
* @brief Get the number of available devices for given backend and device type
Expand Down
10 changes: 6 additions & 4 deletions dpctl-capi/source/dpctl_sycl_device_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,12 @@ bool DPCTLDevice_IsHostUnifiedMemory(__dpctl_keep const DPCTLSyclDeviceRef DRef)
bool DPCTLDevice_AreEq(__dpctl_keep const DPCTLSyclDeviceRef DRef1,
__dpctl_keep const DPCTLSyclDeviceRef DRef2)
{
// Note: DPCPP does not yet support device equality of the form:
// *unwrap(DevRef1) == *unwrap(DevRef2). Till DPCPP is fixed we use the
// custom equality checker implemented inside DPCTLDeviceMgr.
return DPCTLDeviceMgr_AreEq(DRef1, DRef2);
auto D1 = unwrap(DRef1);
auto D2 = unwrap(DRef2);
if (D1 && D2)
return *D1 == *D2;
else
return false;
}

bool DPCTLDevice_HasAspect(__dpctl_keep const DPCTLSyclDeviceRef DRef,
Expand Down
133 changes: 28 additions & 105 deletions dpctl-capi/source/dpctl_sycl_device_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,6 @@ namespace
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(device, DPCTLSyclDeviceRef)
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(context, DPCTLSyclContextRef)

/* Checks if two devices are equal based on the underlying native pointer.
*/
bool deviceEqChecker(const device &D1, const device &D2)
{
if (D1.is_host() && D2.is_host()) {
return true;
}
else if ((D1.is_host() && !D2.is_host()) || (D2.is_host() && !D1.is_host()))
{
return false;
}
else {
return D1.get() == D2.get();
}
}

/*
* Helper function to print the metadata for a sycl::device.
*/
Expand All @@ -80,64 +64,9 @@ void print_device_info(const device &Device)
std::cout << ss.str();
}

/*
* Helper class to store DPCTLSyclDeviceType and DPCTLSyclBackendType attributes
* for a device along with the SYCL device.
*/
struct DeviceWrapper
{
device SyclDevice;
DPCTLSyclBackendType Bty;
DPCTLSyclDeviceType Dty;

DeviceWrapper(const device &Device)
: SyclDevice(Device), Bty(DPCTL_SyclBackendToDPCTLBackendType(
Device.get_platform().get_backend())),
Dty(DPCTL_SyclDeviceTypeToDPCTLDeviceType(
Device.get_info<info::device::device_type>()))
{
}

// The constructor is provided for convenience, so that we do not have to
// lookup the BackendType and DeviceType if not needed.
DeviceWrapper(const device &Device,
DPCTLSyclBackendType Bty,
DPCTLSyclDeviceType Dty)
: SyclDevice(Device), Bty(Bty), Dty(Dty)
{
}
};

auto getHash(const device &d)
{
if (d.is_host()) {
return std::hash<unsigned long long>{}(-1);
}
else {
return std::hash<decltype(d.get())>{}(d.get());
}
}

struct DeviceHasher
{
size_t operator()(const DeviceWrapper &d) const
{
return getHash(d.SyclDevice);
}
};

struct DeviceEqPred
{
bool operator()(const DeviceWrapper &d1, const DeviceWrapper &d2) const
{
return deviceEqChecker(d1.SyclDevice, d2.SyclDevice);
}
};

struct DeviceCacheBuilder
{
using DeviceCache =
std::unordered_map<DeviceWrapper, context, DeviceHasher, DeviceEqPred>;
using DeviceCache = std::unordered_map<device, context>;
/* This function implements a workaround to the current lack of a default
* context per root device in DPC++. The map stores a "default" context for
* each root device, and the QMgrHelper uses the map whenever it creates a
Expand Down Expand Up @@ -181,40 +110,29 @@ struct DeviceCacheBuilder
#include "dpctl_vector_templ.cpp"
#undef EL

bool DPCTLDeviceMgr_AreEq(__dpctl_keep const DPCTLSyclDeviceRef DRef1,
__dpctl_keep const DPCTLSyclDeviceRef DRef2)
DPCTLSyclContextRef
DPCTLDeviceMgr_GetCachedContext(__dpctl_keep const DPCTLSyclDeviceRef DRef)
{
auto D1 = unwrap(DRef1);
auto D2 = unwrap(DRef2);
if (D1 && D2)
return deviceEqChecker(*D1, *D2);
else
return false;
}
DPCTLSyclContextRef CRef = nullptr;

DPCTL_DeviceAndContextPair DPCTLDeviceMgr_GetDeviceAndContextPair(
__dpctl_keep const DPCTLSyclDeviceRef DRef)
{
DPCTL_DeviceAndContextPair rPair{nullptr, nullptr};
auto Device = unwrap(DRef);
if (!Device) {
return rPair;
}
DeviceWrapper DWrapper{*Device, DPCTLSyclBackendType::DPCTL_UNKNOWN_BACKEND,
DPCTLSyclDeviceType::DPCTL_UNKNOWN_DEVICE};
if (!Device)
return CRef;

auto &cache = DeviceCacheBuilder::getDeviceCache();
auto entry = cache.find(DWrapper);
auto entry = cache.find(*Device);
if (entry != cache.end()) {
try {
rPair.DRef = wrap(new device(entry->first.SyclDevice));
rPair.CRef = wrap(new context(entry->second));
CRef = wrap(new context(entry->second));
} catch (std::bad_alloc const &ba) {
std::cerr << ba.what() << std::endl;
rPair.DRef = nullptr;
rPair.CRef = nullptr;
CRef = nullptr;
}
}
return rPair;
else {
std::cerr << "No cached default context for device" << std::endl;
}
return CRef;
}

__dpctl_give DPCTLDeviceVectorRef
Expand All @@ -228,12 +146,14 @@ DPCTLDeviceMgr_GetDevices(int device_identifier)
return nullptr;
}
auto &cache = DeviceCacheBuilder::getDeviceCache();
Devices->reserve(cache.size());

for (const auto &entry : cache) {
if ((device_identifier & entry.first.Bty) &&
(device_identifier & entry.first.Dty))
{
Devices->emplace_back(wrap(new device(entry.first.SyclDevice)));
auto Bty(DPCTL_SyclBackendToDPCTLBackendType(
entry.first.get_platform().get_backend()));
auto Dty(DPCTL_SyclDeviceTypeToDPCTLDeviceType(
entry.first.get_info<info::device::device_type>()));
if ((device_identifier & Bty) && (device_identifier & Dty)) {
Devices->emplace_back(wrap(new device(entry.first)));
}
}
// the wrap function is defined inside dpctl_vector_templ.cpp
Expand All @@ -248,11 +168,14 @@ size_t DPCTLDeviceMgr_GetNumDevices(int device_identifier)
{
size_t nDevices = 0;
auto &cache = DeviceCacheBuilder::getDeviceCache();
for (const auto &entry : cache)
if ((device_identifier & entry.first.Bty) &&
(device_identifier & entry.first.Dty))
for (const auto &entry : cache) {
auto Bty(DPCTL_SyclBackendToDPCTLBackendType(
entry.first.get_platform().get_backend()));
auto Dty(DPCTL_SyclDeviceTypeToDPCTLDeviceType(
entry.first.get_info<info::device::device_type>()));
if ((device_identifier & Bty) && (device_identifier & Dty))
++nDevices;

}
return nDevices;
}

Expand Down
62 changes: 36 additions & 26 deletions dpctl-capi/source/dpctl_sycl_queue_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,13 @@ std::unique_ptr<property_list> create_property_list(int properties)
}

__dpctl_give DPCTLSyclQueueRef
getQueueImpl(__dpctl_take DPCTLSyclContextRef cRef,
__dpctl_take DPCTLSyclDeviceRef dRef,
getQueueImpl(__dpctl_keep DPCTLSyclContextRef cRef,
__dpctl_keep DPCTLSyclDeviceRef dRef,
error_handler_callback *handler,
int properties)
{
DPCTLSyclQueueRef qRef = nullptr;
qRef = DPCTLQueue_Create(cRef, dRef, handler, properties);
DPCTLContext_Delete(cRef);
DPCTLDevice_Delete(dRef);

return qRef;
}

Expand Down Expand Up @@ -216,36 +213,37 @@ DPCTLQueue_Create(__dpctl_keep const DPCTLSyclContextRef CRef,
}

__dpctl_give DPCTLSyclQueueRef
DPCTLQueue_CreateForDevice(__dpctl_keep const DPCTLSyclDeviceRef dRef,
DPCTLQueue_CreateForDevice(__dpctl_keep const DPCTLSyclDeviceRef DRef,
error_handler_callback *handler,
int properties)
{
DPCTLSyclQueueRef qRef = nullptr;
auto Device = unwrap(dRef);
DPCTLSyclContextRef CRef = nullptr;
DPCTLSyclQueueRef QRef = nullptr;
auto Device = unwrap(DRef);

if (!Device) {
std::cerr << "Cannot create queue from NULL device reference.\n";
return qRef;
return QRef;
}
auto cached = DPCTLDeviceMgr_GetDeviceAndContextPair(dRef);
if (cached.CRef) {
qRef = getQueueImpl(cached.CRef, cached.DRef, handler, properties);
}
// We only cache contexts for root devices. If the dRef argument points to
// a sub-device, then the queue manager allocates a new context and creates
// a new queue to retrun to caller. Note that any context for a sub-device
// is not cached.
else {
// Check if a cached default context exists for the device.
CRef = DPCTLDeviceMgr_GetCachedContext(DRef);
// If a cached default context was found, that context will be used to use
// create the new queue. When a default cached context was not found, as
// will be the case for non-root devices, i.e., sub-devices, a new context
// will be allocated. Note that any newly allocated context is not cached.
if (!CRef) {
try {
auto CRef = wrap(new context(*Device));
auto DRef_copy = wrap(new device(*Device));
qRef = getQueueImpl(CRef, DRef_copy, handler, properties);
CRef = wrap(new context(*Device));
} catch (std::bad_alloc const &ba) {
std::cerr << ba.what() << std::endl;
return QRef;
}
}

return qRef;
// At this point we have a valid context and the queue can be allocated.
QRef = getQueueImpl(CRef, DRef, handler, properties);
// Free the context
DPCTLContext_Delete(CRef);
return QRef;
}

/*!
Expand Down Expand Up @@ -304,9 +302,20 @@ DPCTLSyclBackendType DPCTLQueue_GetBackend(__dpctl_keep DPCTLSyclQueueRef QRef)
__dpctl_give DPCTLSyclDeviceRef
DPCTLQueue_GetDevice(__dpctl_keep const DPCTLSyclQueueRef QRef)
{
DPCTLSyclDeviceRef DRef = nullptr;
auto Q = unwrap(QRef);
auto Device = new device(Q->get_device());
return wrap(Device);
if (Q) {
try {
auto Device = new device(Q->get_device());
DRef = wrap(Device);
} catch (std::bad_alloc const &ba) {
std::cerr << ba.what() << '\n';
}
}
else {
std::cerr << "Could not get the device for this queue.\n";
}
return DRef;
}

__dpctl_give DPCTLSyclContextRef
Expand Down Expand Up @@ -438,7 +447,8 @@ DPCTLQueue_SubmitNDRange(__dpctl_keep const DPCTLSyclKernelRef KRef,

void DPCTLQueue_Wait(__dpctl_keep DPCTLSyclQueueRef QRef)
{
// \todo what happens if the QRef is null or a pointer to a valid sycl queue
// \todo what happens if the QRef is null or a pointer to a valid sycl
// queue
auto SyclQueue = unwrap(QRef);
SyclQueue->wait();
}
Expand Down
Loading