Skip to content

Commit c8c64a6

Browse files
EwanCreblejulianmiBensuo
authored
[SYCL][Graph] Make SYCL-Graph functions thread-safe (#10778)
This PR makes the new APIs defined by [sycl_ext_oneapi_graph](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_graph.asciidoc) thread safe. ## Authors Co-authored-by: Pablo Reble <pablo.reble@intel.com> Co-authored-by: Julian Miller <julian.miller@intel.com> Co-authored-by: Ben Tracy <ben.tracy@codeplay.com> Co-authored-by: Ewan Crawford <ewan@codeplay.com> Co-authored-by: Maxime France-Pillois <maxime.francepillois@codeplay.com>
1 parent fdee56c commit c8c64a6

File tree

12 files changed

+697
-8
lines changed

12 files changed

+697
-8
lines changed

sycl/source/detail/graph_impl.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,8 @@ void exec_graph_impl::createCommandBuffers(sycl::device Device) {
345345
}
346346

347347
exec_graph_impl::~exec_graph_impl() {
348+
WriteLock LockImpl(MGraphImpl->MMutex);
349+
348350
// clear all recording queue if not done before (no call to end_recording)
349351
MGraphImpl->clearQueues();
350352

@@ -370,6 +372,8 @@ exec_graph_impl::~exec_graph_impl() {
370372
sycl::event
371373
exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
372374
sycl::detail::CG::StorageInitHelper CGData) {
375+
WriteLock Lock(MMutex);
376+
373377
auto CreateNewEvent([&]() {
374378
auto NewEvent = std::make_shared<sycl::detail::event_impl>(Queue);
375379
NewEvent->setContextImpl(Queue->getContextImplPtr());
@@ -483,6 +487,7 @@ node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
483487
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
484488
}
485489

490+
graph_impl::WriteLock Lock(impl->MMutex);
486491
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(DepImpls);
487492
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
488493
}
@@ -494,6 +499,7 @@ node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
494499
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
495500
}
496501

502+
graph_impl::WriteLock Lock(impl->MMutex);
497503
std::shared_ptr<detail::node_impl> NodeImpl =
498504
impl->add(impl, CGF, {}, DepImpls);
499505
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
@@ -505,13 +511,17 @@ void modifiable_command_graph::make_edge(node &Src, node &Dest) {
505511
std::shared_ptr<detail::node_impl> ReceiverImpl =
506512
sycl::detail::getSyclObjImpl(Dest);
507513

514+
graph_impl::WriteLock Lock(impl->MMutex);
508515
SenderImpl->registerSuccessor(ReceiverImpl,
509516
SenderImpl); // register successor
510517
impl->removeRoot(ReceiverImpl); // remove receiver from root node list
511518
}
512519

513520
command_graph<graph_state::executable>
514521
modifiable_command_graph::finalize(const sycl::property_list &) const {
522+
// Graph is read and written in this scope so we lock
523+
// this graph with full priviledges.
524+
graph_impl::WriteLock Lock(impl->MMutex);
515525
return command_graph<graph_state::executable>{this->impl,
516526
this->impl->getContext()};
517527
}
@@ -549,6 +559,7 @@ bool modifiable_command_graph::begin_recording(queue &RecordingQueue) {
549559

550560
if (QueueImpl->getCommandGraph() == nullptr) {
551561
QueueImpl->setCommandGraph(impl);
562+
graph_impl::WriteLock Lock(impl->MMutex);
552563
impl->addQueue(QueueImpl);
553564
return true;
554565
}
@@ -570,12 +581,16 @@ bool modifiable_command_graph::begin_recording(
570581
return QueueStateChanged;
571582
}
572583

573-
bool modifiable_command_graph::end_recording() { return impl->clearQueues(); }
584+
bool modifiable_command_graph::end_recording() {
585+
graph_impl::WriteLock Lock(impl->MMutex);
586+
return impl->clearQueues();
587+
}
574588

575589
bool modifiable_command_graph::end_recording(queue &RecordingQueue) {
576590
auto QueueImpl = sycl::detail::getSyclObjImpl(RecordingQueue);
577591
if (QueueImpl && QueueImpl->getCommandGraph() == impl) {
578592
QueueImpl->setCommandGraph(nullptr);
593+
graph_impl::WriteLock Lock(impl->MMutex);
579594
impl->removeQueue(QueueImpl);
580595
return true;
581596
}

sycl/source/detail/graph_impl.hpp

Lines changed: 167 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <functional>
2121
#include <list>
2222
#include <set>
23+
#include <shared_mutex>
2324

2425
namespace sycl {
2526
inline namespace _V1 {
@@ -167,6 +168,62 @@ class node_impl {
167168
return nullptr;
168169
}
169170

171+
/// Tests is the caller is similar to Node
172+
/// @return True if the two nodes are similar
173+
bool isSimilar(std::shared_ptr<node_impl> Node) {
174+
if (MCGType != Node->MCGType)
175+
return false;
176+
177+
if (MSuccessors.size() != Node->MSuccessors.size())
178+
return false;
179+
180+
if (MPredecessors.size() != Node->MPredecessors.size())
181+
return false;
182+
183+
if ((MCGType == sycl::detail::CG::CGTYPE::Kernel) &&
184+
(Node->MCGType == sycl::detail::CG::CGTYPE::Kernel)) {
185+
sycl::detail::CGExecKernel *ExecKernelA =
186+
static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get());
187+
sycl::detail::CGExecKernel *ExecKernelB =
188+
static_cast<sycl::detail::CGExecKernel *>(Node->MCommandGroup.get());
189+
190+
if (ExecKernelA->MKernelName.compare(ExecKernelB->MKernelName) != 0)
191+
return false;
192+
}
193+
return true;
194+
}
195+
196+
/// Recursive traversal of successor nodes checking for
197+
/// equivalent node successions in Node
198+
/// @param Node pointer to the starting node for structure comparison
199+
/// @return true is same structure found, false otherwise
200+
bool checkNodeRecursive(std::shared_ptr<node_impl> Node) {
201+
size_t FoundCnt = 0;
202+
for (std::shared_ptr<node_impl> SuccA : MSuccessors) {
203+
for (std::shared_ptr<node_impl> SuccB : Node->MSuccessors) {
204+
if (isSimilar(Node) && SuccA->checkNodeRecursive(SuccB)) {
205+
FoundCnt++;
206+
break;
207+
}
208+
}
209+
}
210+
if (FoundCnt != MSuccessors.size()) {
211+
return false;
212+
}
213+
214+
return true;
215+
}
216+
217+
/// Recusively computes the number of successor nodes
218+
/// @return number of successor nodes
219+
size_t depthSearchCount() const {
220+
size_t NumberOfNodes = 1;
221+
for (const auto &Succ : MSuccessors) {
222+
NumberOfNodes += Succ->depthSearchCount();
223+
}
224+
return NumberOfNodes;
225+
}
226+
170227
private:
171228
/// Creates a copy of the node's CG by casting to it's actual type, then using
172229
/// that to copy construct and create a new unique ptr from that copy.
@@ -180,17 +237,19 @@ class node_impl {
180237
/// Implementation details of command_graph<modifiable>.
181238
class graph_impl {
182239
public:
240+
using ReadLock = std::shared_lock<std::shared_mutex>;
241+
using WriteLock = std::unique_lock<std::shared_mutex>;
242+
243+
/// Protects all the fields that can be changed by class' methods.
244+
mutable std::shared_mutex MMutex;
245+
183246
/// Constructor.
184247
/// @param SyclContext Context to use for graph.
185248
/// @param SyclDevice Device to create nodes with.
186249
graph_impl(const sycl::context &SyclContext, const sycl::device &SyclDevice)
187250
: MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
188251
MEventsMap(), MInorderQueueMap() {}
189252

190-
/// Insert node into list of root nodes.
191-
/// @param Root Node to add to list of root nodes.
192-
void addRoot(const std::shared_ptr<node_impl> &Root);
193-
194253
/// Remove node from list of root nodes.
195254
/// @param Root Node to remove from list of root nodes.
196255
void removeRoot(const std::shared_ptr<node_impl> &Root);
@@ -264,6 +323,7 @@ class graph_impl {
264323
/// @return Event associated with node.
265324
std::shared_ptr<sycl::detail::event_impl>
266325
getEventForNode(std::shared_ptr<node_impl> NodeImpl) const {
326+
ReadLock Lock(MMutex);
267327
if (auto EventImpl = std::find_if(
268328
MEventsMap.begin(), MEventsMap.end(),
269329
[NodeImpl](auto &it) { return it.second == NodeImpl; });
@@ -315,6 +375,95 @@ class graph_impl {
315375
MInorderQueueMap[QueueWeakPtr] = Node;
316376
}
317377

378+
/// Checks if the graph_impl of Graph has a similar structure to
379+
/// the graph_impl of the caller.
380+
/// Graphs are considered similar if they have same numbers of nodes
381+
/// of the same type with similar predecessor and successor nodes (number and
382+
/// type). Two nodes are considered similar if they have the same
383+
/// command-group type. For command-groups of type "kernel", the "signature"
384+
/// of the kernel is also compared (i.e. the name of the command-group).
385+
/// @param Graph if reference to the graph to compare with.
386+
/// @param DebugPrint if set to true throw exception with additional debug
387+
/// information about the spotted graph differences.
388+
/// @return true if the two graphs are similar, false otherwise
389+
bool hasSimilarStructure(std::shared_ptr<detail::graph_impl> Graph,
390+
bool DebugPrint = false) const {
391+
if (this == Graph.get())
392+
return true;
393+
394+
if (MContext != Graph->MContext) {
395+
if (DebugPrint) {
396+
throw sycl::exception(sycl::make_error_code(errc::invalid),
397+
"MContext are not the same.");
398+
}
399+
return false;
400+
}
401+
402+
if (MDevice != Graph->MDevice) {
403+
if (DebugPrint) {
404+
throw sycl::exception(sycl::make_error_code(errc::invalid),
405+
"MDevice are not the same.");
406+
}
407+
return false;
408+
}
409+
410+
if (MEventsMap.size() != Graph->MEventsMap.size()) {
411+
if (DebugPrint) {
412+
throw sycl::exception(sycl::make_error_code(errc::invalid),
413+
"MEventsMap sizes are not the same.");
414+
}
415+
return false;
416+
}
417+
418+
if (MInorderQueueMap.size() != Graph->MInorderQueueMap.size()) {
419+
if (DebugPrint) {
420+
throw sycl::exception(sycl::make_error_code(errc::invalid),
421+
"MInorderQueueMap sizes are not the same.");
422+
}
423+
return false;
424+
}
425+
426+
if (MRoots.size() != Graph->MRoots.size()) {
427+
if (DebugPrint) {
428+
throw sycl::exception(sycl::make_error_code(errc::invalid),
429+
"MRoots sizes are not the same.");
430+
}
431+
return false;
432+
}
433+
434+
size_t RootsFound = 0;
435+
for (std::shared_ptr<node_impl> NodeA : MRoots) {
436+
for (std::shared_ptr<node_impl> NodeB : Graph->MRoots) {
437+
if (NodeA->isSimilar(NodeB)) {
438+
if (NodeA->checkNodeRecursive(NodeB)) {
439+
RootsFound++;
440+
break;
441+
}
442+
}
443+
}
444+
}
445+
446+
if (RootsFound != MRoots.size()) {
447+
if (DebugPrint) {
448+
throw sycl::exception(sycl::make_error_code(errc::invalid),
449+
"Root Nodes do NOT match.");
450+
}
451+
return false;
452+
}
453+
454+
return true;
455+
}
456+
457+
// Returns the number of nodes in the Graph
458+
// @return Number of nodes in the Graph
459+
size_t getNumberOfNodes() const {
460+
size_t NumberOfNodes = 0;
461+
for (const auto &Node : MRoots) {
462+
NumberOfNodes += Node->depthSearchCount();
463+
}
464+
return NumberOfNodes;
465+
}
466+
318467
private:
319468
/// Context associated with this graph.
320469
sycl::context MContext;
@@ -333,11 +482,21 @@ class graph_impl {
333482
std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
334483
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
335484
MInorderQueueMap;
485+
486+
/// Insert node into list of root nodes.
487+
/// @param Root Node to add to list of root nodes.
488+
void addRoot(const std::shared_ptr<node_impl> &Root);
336489
};
337490

338491
/// Class representing the implementation of command_graph<executable>.
339492
class exec_graph_impl {
340493
public:
494+
using ReadLock = std::shared_lock<std::shared_mutex>;
495+
using WriteLock = std::unique_lock<std::shared_mutex>;
496+
497+
/// Protects all the fields that can be changed by class' methods.
498+
mutable std::shared_mutex MMutex;
499+
341500
/// Constructor.
342501
/// @param Context Context to create graph with.
343502
/// @param GraphImpl Modifiable graph implementation to create with.
@@ -413,6 +572,10 @@ class exec_graph_impl {
413572
std::list<std::shared_ptr<node_impl>> MSchedule;
414573
/// Pointer to the modifiable graph impl associated with this executable
415574
/// graph.
575+
/// Thread-safe implementation note: in the current implementation
576+
/// multiple exec_graph_impl can reference the same graph_impl object.
577+
/// This specificity must be taken into account when trying to lock
578+
/// the graph_impl mutex from an exec_graph_impl to avoid deadlock.
416579
std::shared_ptr<graph_impl> MGraphImpl;
417580
/// Map of devices to command buffers.
418581
std::unordered_map<sycl::device, sycl::detail::pi::PiExtCommandBuffer>

sycl/source/detail/queue_impl.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,7 @@ class queue_impl {
684684

685685
void setCommandGraph(
686686
std::shared_ptr<ext::oneapi::experimental::detail::graph_impl> Graph) {
687+
std::lock_guard<std::mutex> Lock(MMutex);
687688
MGraph = Graph;
688689
}
689690

sycl/source/handler.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,11 @@ event handler::finalize() {
456456
std::shared_ptr<ext::oneapi::experimental::detail::node_impl> NodeImpl =
457457
nullptr;
458458

459+
// GraphImpl is read and written in this scope so we lock this graph
460+
// with full priviledges.
461+
ext::oneapi::experimental::detail::graph_impl::WriteLock Lock(
462+
GraphImpl->MMutex);
463+
459464
// Create a new node in the graph representing this command-group
460465
if (MQueue->isInOrder()) {
461466
// In-order queues create implicit linear dependencies between nodes.
@@ -1332,15 +1337,28 @@ void handler::ext_oneapi_graph(
13321337
Graph) {
13331338
MCGType = detail::CG::ExecCommandBuffer;
13341339
auto GraphImpl = detail::getSyclObjImpl(Graph);
1340+
// GraphImpl is only read in this scope so we lock this graph for read only
1341+
ext::oneapi::experimental::detail::graph_impl::ReadLock Lock(
1342+
GraphImpl->MMutex);
1343+
13351344
std::shared_ptr<ext::oneapi::experimental::detail::graph_impl> ParentGraph;
13361345
if (MQueue) {
13371346
ParentGraph = MQueue->getCommandGraph();
13381347
} else {
13391348
ParentGraph = MGraph;
13401349
}
13411350

1351+
ext::oneapi::experimental::detail::graph_impl::WriteLock ParentLock;
13421352
// If a parent graph is set that means we are adding or recording a subgraph
13431353
if (ParentGraph) {
1354+
// ParentGraph is read and written in this scope so we lock this graph
1355+
// with full priviledges.
1356+
// We only lock for Record&Replay API because the graph has already been
1357+
// lock if this function was called from the explicit API function add
1358+
if (MQueue) {
1359+
ParentLock = ext::oneapi::experimental::detail::graph_impl::WriteLock(
1360+
ParentGraph->MMutex);
1361+
}
13441362
// Store the node representing the subgraph in the handler so that we can
13451363
// return it to the user later.
13461364
MSubgraphNode = ParentGraph->addSubgraphNodes(GraphImpl->getSchedule());

sycl/test-e2e/Graph/Explicit/basic_usm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// REQUIRES: level_zero, gpu
2-
// RUN: %{build} -o %t.out
2+
// RUN: %{build_pthread_inc} -o %t.out
33
// RUN: %{run} %t.out
44
// Extra run to check for leaks in Level Zero using ZE_DEBUG
55
// RUN: %if ext_oneapi_level_zero %{env ZE_DEBUG=4 %{run} %t.out 2>&1 | FileCheck %s %}

sycl/test-e2e/Graph/Inputs/basic_usm.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
// and submission of the graph.
33

44
#include "../graph_common.hpp"
5+
#include <thread>
56

67
int main() {
78
queue Queue;
89

910
using T = int;
1011

12+
const unsigned NumThreads = std::thread::hardware_concurrency();
1113
std::vector<T> DataA(Size), DataB(Size), DataC(Size);
1214

1315
std::iota(DataA.begin(), DataA.end(), 1);
@@ -32,8 +34,15 @@ int main() {
3234
// Add commands to graph
3335
add_nodes(Graph, Queue, Size, PtrA, PtrB, PtrC);
3436

37+
Barrier SyncPoint{NumThreads};
38+
3539
auto GraphExec = Graph.finalize();
3640

41+
auto SubmitGraph = [&]() {
42+
SyncPoint.wait();
43+
Queue.submit([&](handler &CGH) { CGH.ext_oneapi_graph(GraphExec); });
44+
};
45+
3746
event Event;
3847
for (unsigned n = 0; n < Iterations; n++) {
3948
Event = Queue.submit([&](handler &CGH) {

0 commit comments

Comments
 (0)