Skip to content

[NFCI][SYCL][Graph] Cleanup after enable_shared_from_this for queue_impl #18748

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sycl/source/detail/async_alloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ std::vector<std::shared_ptr<detail::node_impl>> getDepGraphNodes(
// If this is being recorded from an in-order queue we need to get the last
// in-order node if any, since this will later become a dependency of the
// node being processed here.
if (const auto &LastInOrderNode = Graph->getLastInorderNode(Queue);
if (const auto &LastInOrderNode = Graph->getLastInorderNode(Queue.get());
LastInOrderNode) {
DepNodes.push_back(LastInOrderNode);
}
Expand Down
73 changes: 49 additions & 24 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,14 @@ graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
return NodeImpl;
}

void graph_impl::addQueue(sycl::detail::queue_impl &RecordingQueue) {
MRecordingQueues.insert(RecordingQueue.weak_from_this());
}

void graph_impl::removeQueue(sycl::detail::queue_impl &RecordingQueue) {
MRecordingQueues.erase(RecordingQueue.weak_from_this());
}

bool graph_impl::clearQueues() {
bool AnyQueuesCleared = false;
for (auto &Queue : MRecordingQueues) {
Expand Down Expand Up @@ -689,6 +697,24 @@ bool graph_impl::checkForCycles() {
return CycleFound;
}

std::shared_ptr<node_impl>
graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
if (!Queue) {
assert(0 ==
MInorderQueueMap.count(std::weak_ptr<sycl::detail::queue_impl>{}));
return {};
}
if (0 == MInorderQueueMap.count(Queue->weak_from_this())) {
return {};
}
return MInorderQueueMap[Queue->weak_from_this()];
}

void graph_impl::setLastInorderNode(sycl::detail::queue_impl &Queue,
std::shared_ptr<node_impl> Node) {
MInorderQueueMap[Queue.weak_from_this()] = Node;
}

void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
std::shared_ptr<node_impl> Dest) {
throwIfGraphRecordingQueue("make_edge()");
Expand Down Expand Up @@ -769,11 +795,10 @@ std::vector<sycl::detail::EventImplPtr> graph_impl::getExitNodesEvents(
return Events;
}

void graph_impl::beginRecording(
const std::shared_ptr<sycl::detail::queue_impl> &Queue) {
void graph_impl::beginRecording(sycl::detail::queue_impl &Queue) {
graph_impl::WriteLock Lock(MMutex);
if (!Queue->hasCommandGraph()) {
Queue->setCommandGraph(shared_from_this());
if (!Queue.hasCommandGraph()) {
Queue.setCommandGraph(shared_from_this());
addQueue(Queue);
}
}
Expand Down Expand Up @@ -1003,7 +1028,7 @@ exec_graph_impl::~exec_graph_impl() {
}

sycl::event
exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
exec_graph_impl::enqueue(sycl::detail::queue_impl &Queue,
sycl::detail::CG::StorageInitHelper CGData) {
WriteLock Lock(MMutex);

Expand All @@ -1012,8 +1037,9 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
PartitionsExecutionEvents;

auto CreateNewEvent([&]() {
auto NewEvent = std::make_shared<sycl::detail::event_impl>(Queue);
NewEvent->setContextImpl(Queue->getContextImplPtr());
auto NewEvent =
std::make_shared<sycl::detail::event_impl>(Queue.shared_from_this());
NewEvent->setContextImpl(Queue.getContextImplPtr());
NewEvent->setStateIncomplete();
return NewEvent;
});
Expand All @@ -1035,7 +1061,7 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
CGData.MEvents.push_back(PartitionsExecutionEvents[DepPartition]);
}

auto CommandBuffer = CurrentPartition->MCommandBuffers[Queue->get_device()];
auto CommandBuffer = CurrentPartition->MCommandBuffers[Queue.get_device()];

if (CommandBuffer) {
for (std::vector<sycl::detail::EventImplPtr>::iterator It =
Expand Down Expand Up @@ -1073,10 +1099,10 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
if (CGData.MRequirements.empty() && CGData.MEvents.empty()) {
NewEvent->setSubmissionTime();
ur_result_t Res =
Queue->getAdapter()
Queue.getAdapter()
->call_nocheck<
sycl::detail::UrApiKind::urEnqueueCommandBufferExp>(
Queue->getHandleRef(), CommandBuffer, 0, nullptr, &UREvent);
Queue.getHandleRef(), CommandBuffer, 0, nullptr, &UREvent);
NewEvent->setHandle(UREvent);
if (Res == UR_RESULT_ERROR_INVALID_QUEUE_PROPERTIES) {
throw sycl::exception(
Expand All @@ -1096,7 +1122,8 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
CommandBuffer, nullptr, std::move(CGData));

NewEvent = sycl::detail::Scheduler::getInstance().addCG(
std::move(CommandGroup), Queue, /*EventNeeded=*/true);
std::move(CommandGroup), Queue.shared_from_this(),
/*EventNeeded=*/true);
}
NewEvent->setEventFromSubmittedExecCommandBuffer(true);
} else if ((CurrentPartition->MSchedule.size() > 0) &&
Expand All @@ -1112,10 +1139,11 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
// In case of graph, this queue may differ from the actual execution
// queue. We therefore overload this Queue before submitting the task.
static_cast<sycl::detail::CGHostTask &>(*NodeImpl->MCommandGroup.get())
.MQueue = Queue;
.MQueue = Queue.shared_from_this();

NewEvent = sycl::detail::Scheduler::getInstance().addCG(
NodeImpl->getCGCopy(), Queue, /*EventNeeded=*/true);
NodeImpl->getCGCopy(), Queue.shared_from_this(),
/*EventNeeded=*/true);
}
PartitionsExecutionEvents[CurrentPartition] = NewEvent;
}
Expand Down Expand Up @@ -1844,21 +1872,20 @@ void modifiable_command_graph::begin_recording(
// related to graph at all.
checkGraphPropertiesAndThrow(PropList);

auto QueueImpl = sycl::detail::getSyclObjImpl(RecordingQueue);
assert(QueueImpl);
queue_impl &QueueImpl = *sycl::detail::getSyclObjImpl(RecordingQueue);

if (QueueImpl->hasCommandGraph()) {
if (QueueImpl.hasCommandGraph()) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"begin_recording cannot be called for a queue which "
"is already in the recording state.");
}

if (QueueImpl->get_context() != impl->getContext()) {
if (QueueImpl.get_context() != impl->getContext()) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"begin_recording called for a queue whose context "
"differs from the graph context.");
}
if (QueueImpl->get_device() != impl->getDevice()) {
if (QueueImpl.get_device() != impl->getDevice()) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"begin_recording called for a queue whose device "
"differs from the graph device.");
Expand All @@ -1881,15 +1908,13 @@ void modifiable_command_graph::end_recording() {
}

void modifiable_command_graph::end_recording(queue &RecordingQueue) {
auto QueueImpl = sycl::detail::getSyclObjImpl(RecordingQueue);
if (!QueueImpl)
return;
if (QueueImpl->getCommandGraph() == impl) {
QueueImpl->setCommandGraph(nullptr);
queue_impl &QueueImpl = *sycl::detail::getSyclObjImpl(RecordingQueue);
if (QueueImpl.getCommandGraph() == impl) {
QueueImpl.setCommandGraph(nullptr);
graph_impl::WriteLock Lock(impl->MMutex);
impl->removeQueue(QueueImpl);
}
if (QueueImpl->hasCommandGraph())
if (QueueImpl.hasCommandGraph())
throw sycl::exception(sycl::make_error_code(errc::invalid),
"end_recording called for a queue which is recording "
"to a different graph.");
Expand Down
29 changes: 7 additions & 22 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -878,18 +878,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// Add a queue to the set of queues which are currently recording to this
/// graph.
/// @param RecordingQueue Queue to add to set.
void
addQueue(const std::shared_ptr<sycl::detail::queue_impl> &RecordingQueue) {
MRecordingQueues.insert(RecordingQueue);
}
void addQueue(sycl::detail::queue_impl &RecordingQueue);
Copy link
Contributor Author

@aelovikov-intel aelovikov-intel Jun 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to move this and the ones below into the graph_impl.cpp because queue_impl is incomplete at this point. queue_impl.hpp/graph_impl.hpp include each other (which is totally wrong but not the subject of this PR) and untangling that would be a task that I'm not willing to take on, at least not now.

Another alternative would be to make these templates:

template <typename QueueImplTy = queue_impl>
void foo(QueueImplTy &Q)

to make weak_from_this() template-type-dependent and delay the requirement for a complete type until foo is instantiated (vs defined as of now).

Accepting std::weak_ptr<queue_impl> would also work (and some pre-existing methods do just that already) but I find it lacking in expressiveness - it doesn't communicate expectations about nullptr possibility of if the argument can be in already expired state.


/// Remove a queue from the set of queues which are currently recording to
/// this graph.
/// @param RecordingQueue Queue to remove from set.
void
removeQueue(const std::shared_ptr<sycl::detail::queue_impl> &RecordingQueue) {
MRecordingQueues.erase(RecordingQueue);
}
void removeQueue(sycl::detail::queue_impl &RecordingQueue);

/// Remove all queues which are recording to this graph, also sets all queues
/// cleared back to the executing state.
Expand Down Expand Up @@ -1001,22 +995,13 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// @return Last node in this graph added from \p Queue recording, or empty
/// shared pointer if none.
std::shared_ptr<node_impl>
getLastInorderNode(std::shared_ptr<sycl::detail::queue_impl> Queue) {
std::weak_ptr<sycl::detail::queue_impl> QueueWeakPtr(Queue);
if (0 == MInorderQueueMap.count(QueueWeakPtr)) {
return {};
}
return MInorderQueueMap[QueueWeakPtr];
}
getLastInorderNode(sycl::detail::queue_impl *Queue);

/// Track the last node added to this graph from an in-order queue.
/// @param Queue In-order queue to register \p Node for.
/// @param Node Last node that was added to this graph from \p Queue.
void setLastInorderNode(std::shared_ptr<sycl::detail::queue_impl> Queue,
std::shared_ptr<node_impl> Node) {
std::weak_ptr<sycl::detail::queue_impl> QueueWeakPtr(Queue);
MInorderQueueMap[QueueWeakPtr] = Node;
}
void setLastInorderNode(sycl::detail::queue_impl &Queue,
std::shared_ptr<node_impl> Node);

/// Prints the contents of the graph to a text file in DOT format.
/// @param FilePath Path to the output file.
Expand Down Expand Up @@ -1176,7 +1161,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// Sets the Queue state to queue_state::recording. Adds the queue to the list
/// of recording queues associated with this graph.
/// @param[in] Queue The queue to be recorded from.
void beginRecording(const std::shared_ptr<sycl::detail::queue_impl> &Queue);
void beginRecording(sycl::detail::queue_impl &Queue);

/// Store the last barrier node that was submitted to the queue.
/// @param[in] Queue The queue the barrier was recorded from.
Expand Down Expand Up @@ -1346,7 +1331,7 @@ class exec_graph_impl {
/// @param Queue Command-queue to schedule execution on.
/// @param CGData Command-group data provided by the sycl::handler
/// @return Event associated with the execution of the graph.
sycl::event enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
sycl::event enqueue(sycl::detail::queue_impl &Queue,
sycl::detail::CG::StorageInitHelper CGData);

/// Turns the internal graph representation into UR command-buffers for a
Expand Down
9 changes: 5 additions & 4 deletions sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ event handler::finalize() {

} else {
event GraphCompletionEvent =
impl->MExecGraph->enqueue(MQueue, std::move(impl->CGData));
impl->MExecGraph->enqueue(impl->get_queue(), std::move(impl->CGData));

#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
MLastEvent = getSyclObjImpl(GraphCompletionEvent);
Expand Down Expand Up @@ -870,15 +870,16 @@ event handler::finalize() {
// node can set it as a predecessor.
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
Deps;
if (auto DependentNode = GraphImpl->getLastInorderNode(MQueue)) {
if (auto DependentNode =
GraphImpl->getLastInorderNode(impl->get_queue_or_null())) {
Deps.push_back(std::move(DependentNode));
}
NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps);

// If we are recording an in-order queue remember the new node, so it
// can be used as a dependency for any more nodes recorded from this
// queue.
GraphImpl->setLastInorderNode(MQueue, NodeImpl);
GraphImpl->setLastInorderNode(*MQueue, NodeImpl);
} else {
auto LastBarrierRecordedFromQueue = GraphImpl->getBarrierDep(MQueue);
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
Expand Down Expand Up @@ -1988,7 +1989,7 @@ void handler::depends_on(const detail::EventImplPtr &EventImpl) {
// we need to set it to recording (implements the transitive queue recording
// feature).
if (!QueueGraph) {
EventGraph->beginRecording(MQueue);
EventGraph->beginRecording(impl->get_queue());
}
}

Expand Down
Loading