[SYCL][Graph] Implementation of whole graph update #13220

Merged
merged 8 commits into from
Apr 9, 2024
19 changes: 15 additions & 4 deletions sycl/doc/design/CommandGraph.md
@@ -234,7 +234,9 @@ yet been implemented.

### Design Challenges

Graph update faces significant design challenges in SYCL:
#### Explicit Update

Explicit updates of individual nodes face significant design challenges in SYCL:

* Lambda capture order is explicitly undefined in C++, so the user cannot reason
about the indices of arguments captured by kernel lambdas.
@@ -256,9 +258,18 @@ can be used:
extension](../extensions/proposed/sycl_ext_oneapi_free_function_kernels.asciidoc)
* OpenCL interop kernels created from SPIR-V source at runtime.

A possible future workaround for lambda capture issues could be "Whole-Graph Update"
where if we can guarantee that lambda capture order is the same across two
different recordings we can then match parameter order when updating.
A workaround for the lambda capture issues is the "Whole-Graph Update" feature.
Since the lambda capture order is the same across two different recordings, we
can match the parameter order when updating.

#### Whole-Graph Update

The current implementation of the whole-graph update feature assumes that both
graphs have a similar topology. Currently, the implementation only checks that
both graphs have an identical number of nodes and that nodes at the same index
have the same number of predecessor and successor edges. Further investigation
is needed to determine whether extra checks can be added (e.g. checking that the
nodes and edges were added in the same order).
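
As an illustration of the check described above, here is a minimal standalone
sketch. It does not use the runtime's classes: `Node`, `NodeList`, `makeNodes`,
`addEdge` and `checkTopology` are hypothetical names introduced only for this
example, and the error strings are shortened versions of the ones in this patch.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a graph node; only the edge lists matter here.
struct Node {
  std::vector<Node *> Successors;
  std::vector<Node *> Predecessors;
};

using NodeList = std::vector<std::unique_ptr<Node>>;

// Creates Count nodes with no edges.
NodeList makeNodes(std::size_t Count) {
  NodeList Nodes;
  for (std::size_t I = 0; I < Count; ++I)
    Nodes.push_back(std::make_unique<Node>());
  return Nodes;
}

// Adds a directed edge From -> To between two nodes of the same list.
void addEdge(NodeList &Nodes, std::size_t From, std::size_t To) {
  assert(From < Nodes.size() && To < Nodes.size());
  Nodes[From]->Successors.push_back(Nodes[To].get());
  Nodes[To]->Predecessors.push_back(Nodes[From].get());
}

// Returns an error message if the two lists differ in node count or in the
// per-index successor/predecessor counts, and std::nullopt otherwise.
std::optional<std::string> checkTopology(const NodeList &Exec,
                                         const NodeList &Src) {
  if (Exec.size() != Src.size())
    return "Mismatch found in the number of nodes.";
  for (std::size_t I = 0; I < Exec.size(); ++I) {
    if (Exec[I]->Successors.size() != Src[I]->Successors.size() ||
        Exec[I]->Predecessors.size() != Src[I]->Predecessors.size())
      return "Mismatch found in the number of edges.";
  }
  return std::nullopt;
}
```

Because only counts are compared, two graphs with different wiring can still
pass. For example, a four-node graph with edges 0->1 and 2->3 is
indistinguishable from one with edges 0->3 and 2->1 under this check, which is
the kind of case the extra checks mentioned above would need to catch.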

### Scheduler Integration

56 changes: 51 additions & 5 deletions sycl/source/detail/graph_impl.cpp
@@ -757,8 +757,9 @@ void exec_graph_impl::createCommandBuffers(
exec_graph_impl::exec_graph_impl(sycl::context Context,
const std::shared_ptr<graph_impl> &GraphImpl,
const property_list &PropList)
: MSchedule(), MGraphImpl(GraphImpl), MPiSyncPoints(), MContext(Context),
MRequirements(), MExecutionEvents(),
: MSchedule(), MGraphImpl(GraphImpl), MPiSyncPoints(),
MDevice(GraphImpl->getDevice()), MContext(Context), MRequirements(),
MExecutionEvents(),
MIsUpdatable(PropList.has_property<property::graph::updatable>()) {

// If the graph has been marked as updatable then check if the backend
@@ -1155,9 +1156,56 @@ void exec_graph_impl::duplicateNodes() {
MNodeStorage.insert(MNodeStorage.begin(), NewNodes.begin(), NewNodes.end());
}

void exec_graph_impl::update(std::shared_ptr<graph_impl> GraphImpl) {

if (MDevice != GraphImpl->getDevice()) {
throw sycl::exception(
sycl::make_error_code(errc::invalid),
"Cannot update using a graph created with a different device.");
}
if (MContext != GraphImpl->getContext()) {
throw sycl::exception(
sycl::make_error_code(errc::invalid),
"Cannot update using a graph created with a different context.");
}

if (MNodeStorage.size() != GraphImpl->MNodeStorage.size()) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"Cannot update using a graph with a different "
"topology. Mismatch found in the number of nodes.");
} else {
for (uint32_t i = 0; i < MNodeStorage.size(); ++i) {
if (MNodeStorage[i]->MSuccessors.size() !=
GraphImpl->MNodeStorage[i]->MSuccessors.size() ||
MNodeStorage[i]->MPredecessors.size() !=
GraphImpl->MNodeStorage[i]->MPredecessors.size()) {
throw sycl::exception(
sycl::make_error_code(errc::invalid),
"Cannot update using a graph with a different topology. Mismatch "
"found in the number of edges.");
}

if (MNodeStorage[i]->MCGType != GraphImpl->MNodeStorage[i]->MCGType) {
throw sycl::exception(
sycl::make_error_code(errc::invalid),
"Cannot update using a graph with mismatched node types. Each pair "
"of nodes being updated must have the same type");
}
}
}

for (uint32_t i = 0; i < MNodeStorage.size(); ++i) {
MIDCache.insert(
std::make_pair(GraphImpl->MNodeStorage[i]->MID, MNodeStorage[i]));
}

update(GraphImpl->MNodeStorage);
}
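
The positional pairing in the function above can be sketched in isolation. The
example below is a simplified model, not the runtime code: `Node`, `NodePtr`,
`buildIDCache` and `applyUpdate` are hypothetical names, and the lookup step in
`applyUpdate` is an assumption about how the node-list overload might use the
cache, since that body is not shown in this hunk.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a graph node.
struct Node {
  std::uint64_t ID; // unique identifier, analogous to MID in the patch
  int Arg;          // stand-in for the data that an update replaces
};

using NodePtr = std::shared_ptr<Node>;
using IDCache = std::unordered_map<std::uint64_t, NodePtr>;

// Pairs nodes by index: the ID of Src[I] maps to Exec[I]. Like the patch,
// this uses insert(), so an ID that is already present keeps its old entry.
IDCache buildIDCache(const std::vector<NodePtr> &Exec,
                     const std::vector<NodePtr> &Src) {
  assert(Exec.size() == Src.size() && "node counts must already match");
  IDCache Cache;
  for (std::size_t I = 0; I < Exec.size(); ++I)
    Cache.insert({Src[I]->ID, Exec[I]});
  return Cache;
}

// Assumed follow-up step: resolve each source node through the cache and
// copy its payload into the matching executable node.
void applyUpdate(const IDCache &Cache, const std::vector<NodePtr> &Src) {
  for (const NodePtr &S : Src) {
    auto It = Cache.find(S->ID);
    if (It != Cache.end())
      It->second->Arg = S->Arg;
  }
}
```

Pairing by index is only meaningful if both node lists were built in the same
order, which is why the count and type checks run before the cache is filled.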

void exec_graph_impl::update(std::shared_ptr<node_impl> Node) {
this->update(std::vector<std::shared_ptr<node_impl>>{Node});
}

void exec_graph_impl::update(
const std::vector<std::shared_ptr<node_impl>> Nodes) {

@@ -1598,9 +1646,7 @@ void executable_command_graph::finalizeImpl() {

void executable_command_graph::update(
const command_graph<graph_state::modifiable> &Graph) {
(void)Graph;
throw sycl::exception(sycl::make_error_code(errc::invalid),
"Method not yet implemented");
impl->update(sycl::detail::getSyclObjImpl(Graph));
}

void executable_command_graph::update(const node &Node) {
7 changes: 7 additions & 0 deletions sycl/source/detail/graph_impl.hpp
@@ -1281,6 +1281,10 @@ class exec_graph_impl {
void createCommandBuffers(sycl::device Device,
std::shared_ptr<partition> &Partition);

/// Query for the device tied to this graph.
/// @return Device associated with graph.
sycl::device getDevice() const { return MDevice; }

/// Query for the context tied to this graph.
/// @return Context associated with graph.
sycl::context getContext() const { return MContext; }
@@ -1320,6 +1324,7 @@ class exec_graph_impl {
return MRequirements;
}

void update(std::shared_ptr<graph_impl> GraphImpl);
void update(std::shared_ptr<node_impl> Node);
void update(const std::vector<std::shared_ptr<node_impl>> Nodes);

@@ -1408,6 +1413,8 @@ class exec_graph_impl {
/// Map of nodes in the exec graph to the partition number to which they
/// belong.
std::unordered_map<std::shared_ptr<node_impl>, int> MPartitionNodes;
/// Device associated with this executable graph.
sycl::device MDevice;
/// Context associated with this executable graph.
sycl::context MContext;
/// List of requirements for enqueueing this command graph, accumulated from

This file was deleted.

104 changes: 0 additions & 104 deletions sycl/test-e2e/Graph/Inputs/double_buffer.cpp

This file was deleted.
