Skip to content

[SYCL][Graph] Throw exception when explicit add called on a graph recording a queue #283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,7 @@ bool graph_impl::checkForCycles() {

void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
std::shared_ptr<node_impl> Dest) {
if (MRecordingQueues.size()) {
throw sycl::exception(make_error_code(sycl::errc::invalid),
"make_edge() cannot be called when a queue is "
"currently recording commands to a graph.");
}
throwIfGraphRecordingQueue("make_edge()");
if (Src == Dest) {
throw sycl::exception(
make_error_code(sycl::errc::invalid),
Expand Down Expand Up @@ -610,6 +606,7 @@ modifiable_command_graph::modifiable_command_graph(
PropList)) {}

node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function");
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
for (auto &D : Deps) {
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
Expand All @@ -621,6 +618,7 @@ node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {

node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
const std::vector<node> &Deps) {
impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function");
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
for (auto &D : Deps) {
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
Expand Down
12 changes: 12 additions & 0 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,18 @@ class graph_impl {
void makeEdge(std::shared_ptr<node_impl> Src,
std::shared_ptr<node_impl> Dest);

/// Throws an invalid exception if this function is called
/// while a queue is recording commands to the graph.
/// @param ExceptionMsg Message to append to the exception message
void throwIfGraphRecordingQueue(const std::string ExceptionMsg) const {
if (MRecordingQueues.size()) {
throw sycl::exception(make_error_code(sycl::errc::invalid),
ExceptionMsg +
" cannot be called when a queue "
"is currently recording commands to a graph.");
}
}

private:
/// Iterate over the graph depth-first and run \p NodeFunc on each node.
/// @param NodeFunc A function which receives as input a node in the graph to
Expand Down
13 changes: 9 additions & 4 deletions sycl/test-e2e/Graph/Explicit/while_recording.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
//
// CHECK-NOT: LEAK

// Expected Fail as exception not implemented yet
// XFAIL: *

// Tests attempting to add a node to a command_graph while it is being
// recorded to by a queue is an error.
// The second run is to check that there are no leaks reported with the embedded
Expand All @@ -31,8 +28,16 @@ int main() {
Success = true;
}
}
assert(Success);

Graph.end_recording();
Success = false;
try {
Graph.add({});
} catch (sycl::exception &E) {
Success = E.code() == static_cast<int>(errc::invalid)
}
assert(Success);

Graph.end_recording();
return 0;
}