Skip to content

[SYCL] Refactoring of queue classes #2205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions sycl/include/CL/sycl/queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,17 +145,17 @@ class __SYCL_EXPORT queue {
queue(cl_command_queue ClQueue, const context &SyclContext,
const async_handler &AsyncHandler = {});

queue(const queue &rhs) = default;
queue(const queue &RHS) = default;

queue(queue &&rhs) = default;
queue(queue &&RHS) = default;

queue &operator=(const queue &rhs) = default;
queue &operator=(const queue &RHS) = default;

queue &operator=(queue &&rhs) = default;
queue &operator=(queue &&RHS) = default;

bool operator==(const queue &rhs) const { return impl == rhs.impl; }
bool operator==(const queue &RHS) const { return impl == RHS.impl; }

bool operator!=(const queue &rhs) const { return !(*this == rhs); }
bool operator!=(const queue &RHS) const { return !(*this == RHS); }

/// \return a valid instance of OpenCL queue, which is retained before being
/// returned.
Expand Down Expand Up @@ -317,7 +317,7 @@ class __SYCL_EXPORT queue {
/// \return a copy of the property of type PropertyT that the queue was
/// constructed with. If the queue was not constructed with the PropertyT
/// property, an invalid_object_error SYCL exception is thrown.
template <typename propertyT> propertyT get_property() const;
template <typename PropertyT> PropertyT get_property() const;

/// Fills the memory pointed to by a USM pointer with the value specified.
///
Expand Down Expand Up @@ -900,10 +900,10 @@ class __SYCL_EXPORT queue {

namespace std {
template <> struct hash<cl::sycl::queue> {
size_t operator()(const cl::sycl::queue &q) const {
size_t operator()(const cl::sycl::queue &Q) const {
return std::hash<
cl::sycl::shared_ptr_class<cl::sycl::detail::queue_impl>>()(
cl::sycl::detail::getSyclObjImpl(q));
cl::sycl::detail::getSyclObjImpl(Q));
}
};
} // namespace std
94 changes: 49 additions & 45 deletions sycl/source/detail/queue_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <detail/queue_impl.hpp>

#include <cstring>
#include <utility>

#ifdef XPTI_ENABLE_INSTRUMENTATION
#include "xpti_trace_framework.hpp"
Expand All @@ -27,7 +28,7 @@ template <> cl_uint queue_impl::get_info<info::queue::reference_count>() const {
RT::PiResult result = PI_SUCCESS;
if (!is_host())
getPlugin().call<PiApiKind::piQueueGetInfo>(
MCommandQueue, PI_QUEUE_INFO_REFERENCE_COUNT, sizeof(result), &result,
MQueues[0], PI_QUEUE_INFO_REFERENCE_COUNT, sizeof(result), &result,
nullptr);
return result;
}
Expand All @@ -40,78 +41,80 @@ template <> device queue_impl::get_info<info::queue::device>() const {
return get_device();
}

static event prepareUSMEvent(shared_ptr_class<detail::queue_impl> QueueImpl,
RT::PiEvent NativeEvent) {
static event
prepareUSMEvent(const shared_ptr_class<detail::queue_impl> &QueueImpl,
RT::PiEvent NativeEvent) {
auto EventImpl = std::make_shared<detail::event_impl>(QueueImpl);
EventImpl->getHandleRef() = NativeEvent;
EventImpl->setContextImpl(detail::getSyclObjImpl(QueueImpl->get_context()));
return detail::createSyclObjFromImpl<event>(EventImpl);
}

event queue_impl::memset(shared_ptr_class<detail::queue_impl> Impl, void *Ptr,
int Value, size_t Count) {
context Context = get_context();
RT::PiEvent NativeEvent = nullptr;
MemoryManager::fill_usm(Ptr, Impl, Count, Value, /*DepEvents*/ {},
event queue_impl::memset(const shared_ptr_class<detail::queue_impl> &Self,
void *Ptr, int Value, size_t Count) {
RT::PiEvent NativeEvent{};
MemoryManager::fill_usm(Ptr, Self, Count, Value, /*DepEvents*/ {},
NativeEvent);

if (Context.is_host())
if (MContext->is_host())
return event();

event ResEvent = prepareUSMEvent(Impl, NativeEvent);
event ResEvent = prepareUSMEvent(Self, NativeEvent);
addUSMEvent(ResEvent);
return ResEvent;
}

event queue_impl::memcpy(shared_ptr_class<detail::queue_impl> Impl, void *Dest,
const void *Src, size_t Count) {
context Context = get_context();
RT::PiEvent NativeEvent = nullptr;
MemoryManager::copy_usm(Src, Impl, Count, Dest, /*DepEvents*/ {},
event queue_impl::memcpy(const shared_ptr_class<detail::queue_impl> &Self,
void *Dest, const void *Src, size_t Count) {
RT::PiEvent NativeEvent{};
MemoryManager::copy_usm(Src, Self, Count, Dest, /*DepEvents*/ {},
NativeEvent);

if (Context.is_host())
if (MContext->is_host())
return event();

event ResEvent = prepareUSMEvent(Impl, NativeEvent);
event ResEvent = prepareUSMEvent(Self, NativeEvent);
addUSMEvent(ResEvent);
return ResEvent;
}

event queue_impl::mem_advise(shared_ptr_class<detail::queue_impl> Impl,
event queue_impl::mem_advise(const shared_ptr_class<detail::queue_impl> &Self,
const void *Ptr, size_t Length,
pi_mem_advice Advice) {
context Context = get_context();
if (Context.is_host()) {
if (MContext->is_host()) {
return event();
}

// non-Host device
RT::PiEvent NativeEvent = nullptr;
RT::PiEvent NativeEvent{};
const detail::plugin &Plugin = getPlugin();
Plugin.call<PiApiKind::piextUSMEnqueueMemAdvise>(getHandleRef(), Ptr, Length,
Advice, &NativeEvent);

event ResEvent = prepareUSMEvent(Impl, NativeEvent);
event ResEvent = prepareUSMEvent(Self, NativeEvent);
addUSMEvent(ResEvent);
return ResEvent;
}

void queue_impl::addEvent(event Event) {
void queue_impl::addEvent(const event &Event) {
std::weak_ptr<event_impl> EventWeakPtr{getSyclObjImpl(Event)};
std::lock_guard<mutex_class> Guard(MMutex);
std::lock_guard<mutex_class> Lock(MMutex);
MEvents.push_back(std::move(EventWeakPtr));
}

void queue_impl::addUSMEvent(event Event) {
std::lock_guard<mutex_class> Guard(MMutex);
MUSMEvents.push_back(std::move(Event));
void queue_impl::addUSMEvent(const event &Event) {
std::lock_guard<mutex_class> Lock(MMutex);
MUSMEvents.push_back(Event);
}

void *queue_impl::instrumentationProlog(const detail::code_location &CodeLoc,
string_class &Name, int32_t StreamID,
uint64_t &IId) {
void *TraceEvent = nullptr;
(void)CodeLoc;
(void)Name;
(void)StreamID;
(void)IId;
#ifdef XPTI_ENABLE_INSTRUMENTATION
xpti::trace_event_data_t *WaitEvent = nullptr;
if (!xptiTraceEnabled())
Expand Down Expand Up @@ -172,6 +175,10 @@ void *queue_impl::instrumentationProlog(const detail::code_location &CodeLoc,

void queue_impl::instrumentationEpilog(void *TelemetryEvent, string_class &Name,
int32_t StreamID, uint64_t IId) {
(void)TelemetryEvent;
(void)Name;
(void)StreamID;
(void)IId;
#ifdef XPTI_ENABLE_INSTRUMENTATION
if (!(xptiTraceEnabled() && TelemetryEvent))
return;
Expand All @@ -184,6 +191,7 @@ void queue_impl::instrumentationEpilog(void *TelemetryEvent, string_class &Name,
}

void queue_impl::wait(const detail::code_location &CodeLoc) {
(void)CodeLoc;
#ifdef XPTI_ENABLE_INSTRUMENTATION
void *TelemetryEvent = nullptr;
uint64_t IId;
Expand All @@ -192,24 +200,20 @@ void queue_impl::wait(const detail::code_location &CodeLoc) {
TelemetryEvent = instrumentationProlog(CodeLoc, Name, StreamID, IId);
#endif

std::vector<std::shared_ptr<event_impl>> Events;
vector_class<std::weak_ptr<event_impl>> Events;
vector_class<event> USMEvents;
{
std::lock_guard<mutex_class> Guard(MMutex);
for (std::weak_ptr<event_impl> &EventImplWeakPtr : MEvents)
if (std::shared_ptr<event_impl> EventImplPtr = EventImplWeakPtr.lock())
Events.push_back(EventImplPtr);
std::lock_guard<mutex_class> Lock(MMutex);
Events = std::move(MEvents);
USMEvents = std::move(MUSMEvents);
}

for (std::shared_ptr<event_impl> &Event : Events) {
Event->wait(Event);
}
for (std::weak_ptr<event_impl> &EventImplWeakPtr : Events)
if (std::shared_ptr<event_impl> EventImplPtr = EventImplWeakPtr.lock())
EventImplPtr->wait(EventImplPtr);

for (event &Event : MUSMEvents) {
for (event &Event : USMEvents)
Event.wait();
}

MEvents.clear();
MUSMEvents.clear();

#ifdef XPTI_ENABLE_INSTRUMENTATION
instrumentationEpilog(TelemetryEvent, Name, StreamID, IId);
Expand All @@ -222,9 +226,9 @@ void queue_impl::initHostTaskAndEventCallbackThreadPool() {

int Size = 1;

if (const char *val = std::getenv("SYCL_QUEUE_THREAD_POOL_SIZE"))
if (const char *Val = std::getenv("SYCL_QUEUE_THREAD_POOL_SIZE"))
try {
Size = std::stoi(val);
Size = std::stoi(Val);
} catch (...) {
throw invalid_parameter_error(
"Invalid value for SYCL_QUEUE_THREAD_POOL_SIZE environment variable",
Expand All @@ -241,9 +245,9 @@ void queue_impl::initHostTaskAndEventCallbackThreadPool() {
}

pi_native_handle queue_impl::getNative() const {
auto Plugin = getPlugin();
pi_native_handle Handle;
Plugin.call<PiApiKind::piextQueueGetNativeHandle>(MCommandQueue, &Handle);
const detail::plugin &Plugin = getPlugin();
pi_native_handle Handle{};
Plugin.call<PiApiKind::piextQueueGetNativeHandle>(MQueues[0], &Handle);
return Handle;
}

Expand Down
Loading