20
20
#include < functional>
21
21
#include < list>
22
22
#include < set>
23
+ #include < shared_mutex>
23
24
24
25
namespace sycl {
25
26
inline namespace _V1 {
@@ -167,6 +168,62 @@ class node_impl {
167
168
return nullptr ;
168
169
}
169
170
171
+ // / Tests is the caller is similar to Node
172
+ // / @return True if the two nodes are similar
173
+ bool isSimilar (std::shared_ptr<node_impl> Node) {
174
+ if (MCGType != Node->MCGType )
175
+ return false ;
176
+
177
+ if (MSuccessors.size () != Node->MSuccessors .size ())
178
+ return false ;
179
+
180
+ if (MPredecessors.size () != Node->MPredecessors .size ())
181
+ return false ;
182
+
183
+ if ((MCGType == sycl::detail::CG::CGTYPE::Kernel) &&
184
+ (Node->MCGType == sycl::detail::CG::CGTYPE::Kernel)) {
185
+ sycl::detail::CGExecKernel *ExecKernelA =
186
+ static_cast <sycl::detail::CGExecKernel *>(MCommandGroup.get ());
187
+ sycl::detail::CGExecKernel *ExecKernelB =
188
+ static_cast <sycl::detail::CGExecKernel *>(Node->MCommandGroup .get ());
189
+
190
+ if (ExecKernelA->MKernelName .compare (ExecKernelB->MKernelName ) != 0 )
191
+ return false ;
192
+ }
193
+ return true ;
194
+ }
195
+
196
+ // / Recursive traversal of successor nodes checking for
197
+ // / equivalent node successions in Node
198
+ // / @param Node pointer to the starting node for structure comparison
199
+ // / @return true is same structure found, false otherwise
200
+ bool checkNodeRecursive (std::shared_ptr<node_impl> Node) {
201
+ size_t FoundCnt = 0 ;
202
+ for (std::shared_ptr<node_impl> SuccA : MSuccessors) {
203
+ for (std::shared_ptr<node_impl> SuccB : Node->MSuccessors ) {
204
+ if (isSimilar (Node) && SuccA->checkNodeRecursive (SuccB)) {
205
+ FoundCnt++;
206
+ break ;
207
+ }
208
+ }
209
+ }
210
+ if (FoundCnt != MSuccessors.size ()) {
211
+ return false ;
212
+ }
213
+
214
+ return true ;
215
+ }
216
+
217
+ // / Recusively computes the number of successor nodes
218
+ // / @return number of successor nodes
219
+ size_t depthSearchCount () const {
220
+ size_t NumberOfNodes = 1 ;
221
+ for (const auto &Succ : MSuccessors) {
222
+ NumberOfNodes += Succ->depthSearchCount ();
223
+ }
224
+ return NumberOfNodes;
225
+ }
226
+
170
227
private:
171
228
// / Creates a copy of the node's CG by casting to it's actual type, then using
172
229
// / that to copy construct and create a new unique ptr from that copy.
@@ -180,17 +237,19 @@ class node_impl {
180
237
// / Implementation details of command_graph<modifiable>.
181
238
class graph_impl {
182
239
public:
240
+ using ReadLock = std::shared_lock<std::shared_mutex>;
241
+ using WriteLock = std::unique_lock<std::shared_mutex>;
242
+
243
+ // / Protects all the fields that can be changed by class' methods.
244
+ mutable std::shared_mutex MMutex;
245
+
183
246
// / Constructor.
184
247
// / @param SyclContext Context to use for graph.
185
248
// / @param SyclDevice Device to create nodes with.
186
249
graph_impl (const sycl::context &SyclContext, const sycl::device &SyclDevice)
187
250
: MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
188
251
MEventsMap (), MInorderQueueMap() {}
189
252
190
- // / Insert node into list of root nodes.
191
- // / @param Root Node to add to list of root nodes.
192
- void addRoot (const std::shared_ptr<node_impl> &Root);
193
-
194
253
// / Remove node from list of root nodes.
195
254
// / @param Root Node to remove from list of root nodes.
196
255
void removeRoot (const std::shared_ptr<node_impl> &Root);
@@ -264,6 +323,7 @@ class graph_impl {
264
323
// / @return Event associated with node.
265
324
std::shared_ptr<sycl::detail::event_impl>
266
325
getEventForNode (std::shared_ptr<node_impl> NodeImpl) const {
326
+ ReadLock Lock (MMutex);
267
327
if (auto EventImpl = std::find_if (
268
328
MEventsMap.begin (), MEventsMap.end (),
269
329
[NodeImpl](auto &it) { return it.second == NodeImpl; });
@@ -315,6 +375,95 @@ class graph_impl {
315
375
MInorderQueueMap[QueueWeakPtr] = Node;
316
376
}
317
377
378
+ // / Checks if the graph_impl of Graph has a similar structure to
379
+ // / the graph_impl of the caller.
380
+ // / Graphs are considered similar if they have same numbers of nodes
381
+ // / of the same type with similar predecessor and successor nodes (number and
382
+ // / type). Two nodes are considered similar if they have the same
383
+ // / command-group type. For command-groups of type "kernel", the "signature"
384
+ // / of the kernel is also compared (i.e. the name of the command-group).
385
+ // / @param Graph if reference to the graph to compare with.
386
+ // / @param DebugPrint if set to true throw exception with additional debug
387
+ // / information about the spotted graph differences.
388
+ // / @return true if the two graphs are similar, false otherwise
389
+ bool hasSimilarStructure (std::shared_ptr<detail::graph_impl> Graph,
390
+ bool DebugPrint = false ) const {
391
+ if (this == Graph.get ())
392
+ return true ;
393
+
394
+ if (MContext != Graph->MContext ) {
395
+ if (DebugPrint) {
396
+ throw sycl::exception (sycl::make_error_code (errc::invalid),
397
+ " MContext are not the same." );
398
+ }
399
+ return false ;
400
+ }
401
+
402
+ if (MDevice != Graph->MDevice ) {
403
+ if (DebugPrint) {
404
+ throw sycl::exception (sycl::make_error_code (errc::invalid),
405
+ " MDevice are not the same." );
406
+ }
407
+ return false ;
408
+ }
409
+
410
+ if (MEventsMap.size () != Graph->MEventsMap .size ()) {
411
+ if (DebugPrint) {
412
+ throw sycl::exception (sycl::make_error_code (errc::invalid),
413
+ " MEventsMap sizes are not the same." );
414
+ }
415
+ return false ;
416
+ }
417
+
418
+ if (MInorderQueueMap.size () != Graph->MInorderQueueMap .size ()) {
419
+ if (DebugPrint) {
420
+ throw sycl::exception (sycl::make_error_code (errc::invalid),
421
+ " MInorderQueueMap sizes are not the same." );
422
+ }
423
+ return false ;
424
+ }
425
+
426
+ if (MRoots.size () != Graph->MRoots .size ()) {
427
+ if (DebugPrint) {
428
+ throw sycl::exception (sycl::make_error_code (errc::invalid),
429
+ " MRoots sizes are not the same." );
430
+ }
431
+ return false ;
432
+ }
433
+
434
+ size_t RootsFound = 0 ;
435
+ for (std::shared_ptr<node_impl> NodeA : MRoots) {
436
+ for (std::shared_ptr<node_impl> NodeB : Graph->MRoots ) {
437
+ if (NodeA->isSimilar (NodeB)) {
438
+ if (NodeA->checkNodeRecursive (NodeB)) {
439
+ RootsFound++;
440
+ break ;
441
+ }
442
+ }
443
+ }
444
+ }
445
+
446
+ if (RootsFound != MRoots.size ()) {
447
+ if (DebugPrint) {
448
+ throw sycl::exception (sycl::make_error_code (errc::invalid),
449
+ " Root Nodes do NOT match." );
450
+ }
451
+ return false ;
452
+ }
453
+
454
+ return true ;
455
+ }
456
+
457
+ // Returns the number of nodes in the Graph
458
+ // @return Number of nodes in the Graph
459
+ size_t getNumberOfNodes () const {
460
+ size_t NumberOfNodes = 0 ;
461
+ for (const auto &Node : MRoots) {
462
+ NumberOfNodes += Node->depthSearchCount ();
463
+ }
464
+ return NumberOfNodes;
465
+ }
466
+
318
467
private:
319
468
// / Context associated with this graph.
320
469
sycl::context MContext;
@@ -333,11 +482,21 @@ class graph_impl {
333
482
std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
334
483
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
335
484
MInorderQueueMap;
485
+
486
+ // / Insert node into list of root nodes.
487
+ // / @param Root Node to add to list of root nodes.
488
+ void addRoot (const std::shared_ptr<node_impl> &Root);
336
489
};
337
490
338
491
// / Class representing the implementation of command_graph<executable>.
339
492
class exec_graph_impl {
340
493
public:
494
+ using ReadLock = std::shared_lock<std::shared_mutex>;
495
+ using WriteLock = std::unique_lock<std::shared_mutex>;
496
+
497
+ // / Protects all the fields that can be changed by class' methods.
498
+ mutable std::shared_mutex MMutex;
499
+
341
500
// / Constructor.
342
501
// / @param Context Context to create graph with.
343
502
// / @param GraphImpl Modifiable graph implementation to create with.
@@ -413,6 +572,10 @@ class exec_graph_impl {
413
572
std::list<std::shared_ptr<node_impl>> MSchedule;
414
573
// / Pointer to the modifiable graph impl associated with this executable
415
574
// / graph.
575
+ // / Thread-safe implementation note: in the current implementation
576
+ // / multiple exec_graph_impl can reference the same graph_impl object.
577
+ // / This specificity must be taken into account when trying to lock
578
+ // / the graph_impl mutex from an exec_graph_impl to avoid deadlock.
416
579
std::shared_ptr<graph_impl> MGraphImpl;
417
580
// / Map of devices to command buffers.
418
581
std::unordered_map<sycl::device, sycl::detail::pi::PiExtCommandBuffer>
0 commit comments