Skip to content

Commit 5337a8a

Browse files
authored
[SYCL][Graph] Add node and graph queries for mixed usage (#12366)
This PR adds queries to both nodes and modifiable graphs which enable better mixed usage of both the explicit and record & replay APIs in a single program. It also reworks how subgraphs are handled: previously nodes were merged into the modifiable graph, but this would pose a problem for users querying the graph since they would not see a single subgraph node, and this merging behaviour was an implementation detail. This has been changed so that now subgraph nodes are only merged in the executable graph, and are stored as a single node of type `subgraph` in the modifiable graph. As a consequence of this change all nodes are now also copied when making the executable graph, where previously they were not. - Reworked how subgraphs are handled - Add graph and node queries to the SYCL-Graph spec - Implement graph and node queries - New node_type enum - Explicit nodes now also have associated events (fixes mixed usage issue) - New tests for queries - Update ABI symbols
1 parent 24ce45c commit 5337a8a

File tree

11 files changed

+805
-263
lines changed

11 files changed

+805
-263
lines changed

sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc

Lines changed: 90 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,19 @@ enum class graph_support_level {
315315
emulated
316316
};
317317
318+
enum class node_type {
319+
empty,
320+
subgraph,
321+
kernel,
322+
memcpy,
323+
memset,
324+
memfill,
325+
prefetch,
326+
memadvise,
327+
ext_oneapi_barrier,
328+
host_task,
329+
};
330+
318331
namespace property {
319332
320333
namespace graph {
@@ -355,7 +368,18 @@ struct graphs_support;
355368
} // namespace device
356369
} // namespace info
357370
358-
class node {};
371+
class node {
372+
public:
373+
node() = delete;
374+
375+
node_type get_type() const;
376+
377+
std::vector<node> get_predecessors() const;
378+
379+
std::vector<node> get_successors() const;
380+
381+
static node get_node_from_event(event nodeEvent);
382+
};
359383
360384
// State of a graph
361385
enum class graph_state {
@@ -394,6 +418,9 @@ public:
394418
void make_edge(node& src, node& dest);
395419
396420
void print_graph(std::string path, bool verbose = false) const;
421+
422+
std::vector<node> get_nodes() const;
423+
std::vector<node> get_root_nodes() const;
397424
};
398425
399426
template<>
@@ -467,12 +494,56 @@ edges.
467494

468495
The `node` class provides the {crs}[common reference semantics].
469496

497+
==== Node Member Functions
498+
499+
Table {counter: tableNumber}. Member functions of the `node` class.
500+
[cols="2a,a"]
501+
|===
502+
|Member Function|Description
503+
504+
|
470505
[source,c++]
471506
----
472-
namespace sycl::ext::oneapi::experimental {
473-
class node {};
474-
}
507+
node_type get_type() const;
508+
----
509+
|Returns a value representing the type of command this node represents.
510+
511+
|
512+
[source,c++]
513+
----
514+
std::vector<node> get_predecessors() const;
515+
----
516+
|Returns a list of the predecessor nodes which this node directly depends on.
517+
518+
|
519+
[source,c++]
520+
----
521+
std::vector<node> get_successors() const;
475522
----
523+
|Returns a list of the successor nodes which directly depend on this node.
524+
525+
|
526+
[source,c++]
527+
----
528+
static node get_node_from_event(event nodeEvent);
529+
----
530+
|Finds the node associated with an event created from a submission to a queue
531+
in the recording state.
532+
533+
Parameters:
534+
535+
* `nodeEvent` - Event returned from a submission to a queue in the recording
536+
state.
537+
538+
Returns: Graph node that was created when the command that returned
539+
`nodeEvent` was submitted.
540+
541+
Exceptions:
542+
543+
* Throws with error code `invalid` if `nodeEvent` is not associated with a
544+
graph node.
545+
546+
|===
476547

477548
==== Depends-On Property
478549

@@ -809,6 +880,21 @@ Exceptions:
809880
* Throws synchronously with error code `invalid` if the path is invalid or
810881
the file extension is not supported or if the write operation failed.
811882

883+
|
884+
[source,c++]
885+
----
886+
std::vector<node> get_nodes() const;
887+
----
888+
|Returns a list of all the nodes present in the graph in the order that they
889+
were added.
890+
891+
|
892+
[source,c++]
893+
----
894+
std::vector<node> get_root_nodes() const;
895+
----
896+
|Returns a list of all nodes in the graph which have no dependencies.
897+
812898
|===
813899

814900
Table {counter: tableNumber}. Member functions of the `command_graph` class for queue recording.

sycl/include/sycl/detail/cg.hpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ inline namespace _V1 {
3333
// Forward declarations
3434
class queue;
3535

36+
namespace ext::oneapi::experimental::detail {
37+
class exec_graph_impl;
38+
}
39+
3640
namespace detail {
3741

3842
class event_impl;
@@ -573,11 +577,16 @@ class CGSemaphoreSignal : public CG {
573577
class CGExecCommandBuffer : public CG {
574578
public:
575579
sycl::detail::pi::PiExtCommandBuffer MCommandBuffer;
576-
577-
CGExecCommandBuffer(const sycl::detail::pi::PiExtCommandBuffer &CommandBuffer,
578-
CG::StorageInitHelper CGData)
580+
std::shared_ptr<sycl::ext::oneapi::experimental::detail::exec_graph_impl>
581+
MExecGraph;
582+
583+
CGExecCommandBuffer(
584+
const sycl::detail::pi::PiExtCommandBuffer &CommandBuffer,
585+
const std::shared_ptr<
586+
sycl::ext::oneapi::experimental::detail::exec_graph_impl> &ExecGraph,
587+
CG::StorageInitHelper CGData)
579588
: CG(CGTYPE::ExecCommandBuffer, std::move(CGData)),
580-
MCommandBuffer(CommandBuffer) {}
589+
MCommandBuffer(CommandBuffer), MExecGraph(ExecGraph) {}
581590
};
582591

583592
} // namespace detail

sycl/include/sycl/ext/oneapi/experimental/graph.hpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,37 @@ enum class graph_state {
8282
executable, ///< In executable state, the graph is ready to execute.
8383
};
8484

85+
enum class node_type {
86+
empty = 0,
87+
subgraph = 1,
88+
kernel = 2,
89+
memcpy = 3,
90+
memset = 4,
91+
memfill = 5,
92+
prefetch = 6,
93+
memadvise = 7,
94+
ext_oneapi_barrier = 8,
95+
host_task = 9
96+
};
97+
8598
/// Class representing a node in the graph, returned by command_graph::add().
8699
class __SYCL_EXPORT node {
100+
public:
101+
node() = delete;
102+
103+
/// Get the type of command associated with this node.
104+
node_type get_type() const;
105+
106+
/// Get a list of all the node dependencies of this node.
107+
std::vector<node> get_predecessors() const;
108+
109+
/// Get a list of all nodes which depend on this node.
110+
std::vector<node> get_successors() const;
111+
112+
/// Get the node associated with a SYCL event returned from a queue recording
113+
/// submission.
114+
static node get_node_from_event(event nodeEvent);
115+
87116
private:
88117
node(const std::shared_ptr<detail::node_impl> &Impl) : impl(Impl) {}
89118

@@ -259,6 +288,12 @@ class __SYCL_EXPORT modifiable_command_graph {
259288
/// as kernel args or memory access where applicable.
260289
void print_graph(const std::string path, bool verbose = false) const;
261290

291+
/// Get a list of all nodes contained in this graph.
292+
std::vector<node> get_nodes() const;
293+
294+
/// Get a list of all root nodes (nodes without dependencies) in this graph.
295+
std::vector<node> get_root_nodes() const;
296+
262297
protected:
263298
/// Constructor used internally by the runtime.
264299
/// @param Impl Detail implementation class to construct object with.

sycl/include/sycl/handler.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1777,6 +1777,14 @@ class __SYCL_EXPORT handler {
17771777
std::shared_ptr<ext::oneapi::experimental::detail::graph_impl>
17781778
getCommandGraph() const;
17791779

1780+
/// Sets the user facing node type of this operation, used for operations
1781+
/// which are recorded to a graph. Since some operations may actually be a
1782+
/// different type than the user submitted, e.g. a fill() which is performed
1783+
/// as a kernel submission.
1784+
/// @param Type The actual type based on what handler functions the user
1785+
/// called.
1786+
void setUserFacingNodeType(ext::oneapi::experimental::node_type Type);
1787+
17801788
public:
17811789
handler(const handler &) = delete;
17821790
handler(handler &&) = delete;
@@ -2720,6 +2728,7 @@ class __SYCL_EXPORT handler {
27202728
checkIfPlaceholderIsBoundToHandler(Dst);
27212729

27222730
throwIfActionIsCreated();
2731+
setUserFacingNodeType(ext::oneapi::experimental::node_type::memfill);
27232732
// TODO add check:T must be an integral scalar value or a SYCL vector type
27242733
static_assert(isValidTargetForExplicitOp(AccessTarget),
27252734
"Invalid accessor target for the fill method.");
@@ -2758,6 +2767,7 @@ class __SYCL_EXPORT handler {
27582767
/// \param Count is the number of times to fill Pattern into Ptr.
27592768
template <typename T> void fill(void *Ptr, const T &Pattern, size_t Count) {
27602769
throwIfActionIsCreated();
2770+
setUserFacingNodeType(ext::oneapi::experimental::node_type::memfill);
27612771
static_assert(is_device_copyable<T>::value,
27622772
"Pattern must be device copyable");
27632773
parallel_for<__usmfill<T>>(range<1>(Count), [=](id<1> Index) {

0 commit comments

Comments
 (0)