Skip to content

[SYCL][Graph] Make SYCL-Graph functions thread-safe #10778

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 23, 2023
17 changes: 16 additions & 1 deletion sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,8 @@ void exec_graph_impl::createCommandBuffers(sycl::device Device) {
}

exec_graph_impl::~exec_graph_impl() {
WriteLock LockImpl(MGraphImpl->MMutex);

// clear all recording queue if not done before (no call to end_recording)
MGraphImpl->clearQueues();

Expand All @@ -368,6 +370,8 @@ exec_graph_impl::~exec_graph_impl() {
sycl::event
exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
sycl::detail::CG::StorageInitHelper CGData) {
WriteLock Lock(MMutex);

auto CreateNewEvent([&]() {
auto NewEvent = std::make_shared<sycl::detail::event_impl>(Queue);
NewEvent->setContextImpl(Queue->getContextImplPtr());
Expand Down Expand Up @@ -481,6 +485,7 @@ node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
}

graph_impl::WriteLock Lock(impl->MMutex);
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(DepImpls);
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
}
Expand All @@ -492,6 +497,7 @@ node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
}

graph_impl::WriteLock Lock(impl->MMutex);
std::shared_ptr<detail::node_impl> NodeImpl =
impl->add(impl, CGF, {}, DepImpls);
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
Expand All @@ -503,13 +509,17 @@ void modifiable_command_graph::make_edge(node &Src, node &Dest) {
std::shared_ptr<detail::node_impl> ReceiverImpl =
sycl::detail::getSyclObjImpl(Dest);

graph_impl::WriteLock Lock(impl->MMutex);
SenderImpl->registerSuccessor(ReceiverImpl,
SenderImpl); // register successor
impl->removeRoot(ReceiverImpl); // remove receiver from root node list
}

command_graph<graph_state::executable>
modifiable_command_graph::finalize(const sycl::property_list &) const {
// Graph is read and written in this scope so we lock
// this graph with full priviledges.
graph_impl::WriteLock Lock(impl->MMutex);
return command_graph<graph_state::executable>{this->impl,
this->impl->getContext()};
}
Expand All @@ -536,6 +546,7 @@ bool modifiable_command_graph::begin_recording(queue &RecordingQueue) {

if (QueueImpl->getCommandGraph() == nullptr) {
QueueImpl->setCommandGraph(impl);
graph_impl::WriteLock Lock(impl->MMutex);
impl->addQueue(QueueImpl);
return true;
}
Expand All @@ -557,12 +568,16 @@ bool modifiable_command_graph::begin_recording(
return QueueStateChanged;
}

bool modifiable_command_graph::end_recording() { return impl->clearQueues(); }
bool modifiable_command_graph::end_recording() {
graph_impl::WriteLock Lock(impl->MMutex);
return impl->clearQueues();
}

bool modifiable_command_graph::end_recording(queue &RecordingQueue) {
auto QueueImpl = sycl::detail::getSyclObjImpl(RecordingQueue);
if (QueueImpl->getCommandGraph() == impl) {
QueueImpl->setCommandGraph(nullptr);
graph_impl::WriteLock Lock(impl->MMutex);
impl->removeQueue(QueueImpl);
return true;
}
Expand Down
171 changes: 167 additions & 4 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <functional>
#include <list>
#include <set>
#include <shared_mutex>

namespace sycl {
inline namespace _V1 {
Expand Down Expand Up @@ -167,6 +168,62 @@ class node_impl {
return nullptr;
}

/// Tests is the caller is similar to Node
/// @return True if the two nodes are similar
bool isSimilar(std::shared_ptr<node_impl> Node) {
if (MCGType != Node->MCGType)
return false;

if (MSuccessors.size() != Node->MSuccessors.size())
return false;

if (MPredecessors.size() != Node->MPredecessors.size())
return false;

if ((MCGType == sycl::detail::CG::CGTYPE::Kernel) &&
(Node->MCGType == sycl::detail::CG::CGTYPE::Kernel)) {
sycl::detail::CGExecKernel *ExecKernelA =
static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get());
sycl::detail::CGExecKernel *ExecKernelB =
static_cast<sycl::detail::CGExecKernel *>(Node->MCommandGroup.get());

if (ExecKernelA->MKernelName.compare(ExecKernelB->MKernelName) != 0)
return false;
}
return true;
}

/// Recursive traversal of successor nodes checking for
/// equivalent node successions in Node
/// @param Node pointer to the starting node for structure comparison
/// @return true is same structure found, false otherwise
bool checkNodeRecursive(std::shared_ptr<node_impl> Node) {
size_t FoundCnt = 0;
for (std::shared_ptr<node_impl> SuccA : MSuccessors) {
for (std::shared_ptr<node_impl> SuccB : Node->MSuccessors) {
if (isSimilar(Node) && SuccA->checkNodeRecursive(SuccB)) {
FoundCnt++;
break;
}
}
}
if (FoundCnt != MSuccessors.size()) {
return false;
}

return true;
}

/// Recusively computes the number of successor nodes
/// @return number of successor nodes
size_t depthSearchCount() const {
size_t NumberOfNodes = 1;
for (const auto &Succ : MSuccessors) {
NumberOfNodes += Succ->depthSearchCount();
}
return NumberOfNodes;
}

private:
/// Creates a copy of the node's CG by casting to it's actual type, then using
/// that to copy construct and create a new unique ptr from that copy.
Expand All @@ -180,17 +237,19 @@ class node_impl {
/// Implementation details of command_graph<modifiable>.
class graph_impl {
public:
using ReadLock = std::shared_lock<std::shared_mutex>;
using WriteLock = std::unique_lock<std::shared_mutex>;

/// Protects all the fields that can be changed by class' methods.
mutable std::shared_mutex MMutex;

/// Constructor.
/// @param SyclContext Context to use for graph.
/// @param SyclDevice Device to create nodes with.
graph_impl(const sycl::context &SyclContext, const sycl::device &SyclDevice)
: MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
MEventsMap(), MInorderQueueMap() {}

/// Insert node into list of root nodes.
/// @param Root Node to add to list of root nodes.
void addRoot(const std::shared_ptr<node_impl> &Root);

/// Remove node from list of root nodes.
/// @param Root Node to remove from list of root nodes.
void removeRoot(const std::shared_ptr<node_impl> &Root);
Expand Down Expand Up @@ -264,6 +323,7 @@ class graph_impl {
/// @return Event associated with node.
std::shared_ptr<sycl::detail::event_impl>
getEventForNode(std::shared_ptr<node_impl> NodeImpl) const {
ReadLock Lock(MMutex);
if (auto EventImpl = std::find_if(
MEventsMap.begin(), MEventsMap.end(),
[NodeImpl](auto &it) { return it.second == NodeImpl; });
Expand Down Expand Up @@ -315,6 +375,95 @@ class graph_impl {
MInorderQueueMap[QueueWeakPtr] = Node;
}

/// Checks if the graph_impl of Graph has a similar structure to
/// the graph_impl of the caller.
/// Graphs are considered similar if they have same numbers of nodes
/// of the same type with similar predecessor and successor nodes (number and
/// type). Two nodes are considered similar if they have the same
/// command-group type. For command-groups of type "kernel", the "signature"
/// of the kernel is also compared (i.e. the name of the command-group).
/// @param Graph if reference to the graph to compare with.
/// @param DebugPrint if set to true throw exception with additional debug
/// information about the spotted graph differences.
/// @return true if the two graphs are similar, false otherwise
bool hasSimilarStructure(std::shared_ptr<detail::graph_impl> Graph,
bool DebugPrint = false) const {
if (this == Graph.get())
return true;

if (MContext != Graph->MContext) {
if (DebugPrint) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"MContext are not the same.");
}
return false;
}

if (MDevice != Graph->MDevice) {
if (DebugPrint) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"MDevice are not the same.");
}
return false;
}

if (MEventsMap.size() != Graph->MEventsMap.size()) {
if (DebugPrint) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"MEventsMap sizes are not the same.");
}
return false;
}

if (MInorderQueueMap.size() != Graph->MInorderQueueMap.size()) {
if (DebugPrint) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"MInorderQueueMap sizes are not the same.");
}
return false;
}

if (MRoots.size() != Graph->MRoots.size()) {
if (DebugPrint) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"MRoots sizes are not the same.");
}
return false;
}

size_t RootsFound = 0;
for (std::shared_ptr<node_impl> NodeA : MRoots) {
for (std::shared_ptr<node_impl> NodeB : Graph->MRoots) {
if (NodeA->isSimilar(NodeB)) {
if (NodeA->checkNodeRecursive(NodeB)) {
RootsFound++;
break;
}
}
}
}

if (RootsFound != MRoots.size()) {
if (DebugPrint) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"Root Nodes do NOT match.");
}
return false;
}

return true;
}

// Returns the number of nodes in the Graph
// @return Number of nodes in the Graph
size_t getNumberOfNodes() const {
size_t NumberOfNodes = 0;
for (const auto &Node : MRoots) {
NumberOfNodes += Node->depthSearchCount();
}
return NumberOfNodes;
}

private:
/// Context associated with this graph.
sycl::context MContext;
Expand All @@ -333,11 +482,21 @@ class graph_impl {
std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
MInorderQueueMap;

/// Insert node into list of root nodes.
/// @param Root Node to add to list of root nodes.
void addRoot(const std::shared_ptr<node_impl> &Root);
};

/// Class representing the implementation of command_graph<executable>.
class exec_graph_impl {
public:
using ReadLock = std::shared_lock<std::shared_mutex>;
using WriteLock = std::unique_lock<std::shared_mutex>;

/// Protects all the fields that can be changed by class' methods.
mutable std::shared_mutex MMutex;

/// Constructor.
/// @param Context Context to create graph with.
/// @param GraphImpl Modifiable graph implementation to create with.
Expand Down Expand Up @@ -413,6 +572,10 @@ class exec_graph_impl {
std::list<std::shared_ptr<node_impl>> MSchedule;
/// Pointer to the modifiable graph impl associated with this executable
/// graph.
/// Thread-safe implementation note: in the current implementation
/// multiple exec_graph_impl can reference the same graph_impl object.
/// This specificity must be taken into account when trying to lock
/// the graph_impl mutex from an exec_graph_impl to avoid deadlock.
std::shared_ptr<graph_impl> MGraphImpl;
/// Map of devices to command buffers.
std::unordered_map<sycl::device, sycl::detail::pi::PiExtCommandBuffer>
Expand Down
1 change: 1 addition & 0 deletions sycl/source/detail/queue_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,7 @@ class queue_impl {

void setCommandGraph(
std::shared_ptr<ext::oneapi::experimental::detail::graph_impl> Graph) {
std::lock_guard<std::mutex> Lock(MMutex);
MGraph = Graph;
}

Expand Down
18 changes: 18 additions & 0 deletions sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,11 @@ event handler::finalize() {
std::shared_ptr<ext::oneapi::experimental::detail::node_impl> NodeImpl =
nullptr;

// GraphImpl is read and written in this scope so we lock this graph
// with full priviledges.
ext::oneapi::experimental::detail::graph_impl::WriteLock Lock(
GraphImpl->MMutex);

// Create a new node in the graph representing this command-group
if (MQueue->isInOrder()) {
// In-order queues create implicit linear dependencies between nodes.
Expand Down Expand Up @@ -1332,15 +1337,28 @@ void handler::ext_oneapi_graph(
Graph) {
MCGType = detail::CG::ExecCommandBuffer;
auto GraphImpl = detail::getSyclObjImpl(Graph);
// GraphImpl is only read in this scope so we lock this graph for read only
ext::oneapi::experimental::detail::graph_impl::ReadLock Lock(
GraphImpl->MMutex);

std::shared_ptr<ext::oneapi::experimental::detail::graph_impl> ParentGraph;
if (MQueue) {
ParentGraph = MQueue->getCommandGraph();
} else {
ParentGraph = MGraph;
}

ext::oneapi::experimental::detail::graph_impl::WriteLock ParentLock;
// If a parent graph is set that means we are adding or recording a subgraph
if (ParentGraph) {
// ParentGraph is read and written in this scope so we lock this graph
// with full priviledges.
// We only lock for Record&Replay API because the graph has already been
// lock if this function was called from the explicit API function add
if (MQueue) {
ParentLock = ext::oneapi::experimental::detail::graph_impl::WriteLock(
ParentGraph->MMutex);
}
// Store the node representing the subgraph in the handler so that we can
// return it to the user later.
MSubgraphNode = ParentGraph->addSubgraphNodes(GraphImpl->getSchedule());
Expand Down
2 changes: 1 addition & 1 deletion sycl/test-e2e/Graph/Explicit/basic_usm.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// REQUIRES: level_zero, gpu
// RUN: %{build} -o %t.out
// RUN: %{build_pthread_inc} -o %t.out
// RUN: %{run} %t.out
// Extra run to check for leaks in Level Zero using ZE_DEBUG
// RUN: %if ext_oneapi_level_zero %{env ZE_DEBUG=4 %{run} %t.out 2>&1 | FileCheck %s %}
Expand Down
9 changes: 9 additions & 0 deletions sycl/test-e2e/Graph/Inputs/basic_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
// and submission of the graph.

#include "../graph_common.hpp"
#include <thread>

int main() {
queue Queue;

using T = int;

const unsigned NumThreads = std::thread::hardware_concurrency();
std::vector<T> DataA(Size), DataB(Size), DataC(Size);

std::iota(DataA.begin(), DataA.end(), 1);
Expand All @@ -32,8 +34,15 @@ int main() {
// Add commands to graph
add_nodes(Graph, Queue, Size, PtrA, PtrB, PtrC);

Barrier SyncPoint{NumThreads};

auto GraphExec = Graph.finalize();

auto SubmitGraph = [&]() {
SyncPoint.wait();
Queue.submit([&](handler &CGH) { CGH.ext_oneapi_graph(GraphExec); });
};

event Event;
for (unsigned n = 0; n < Iterations; n++) {
Event = Queue.submit([&](handler &CGH) {
Expand Down
Loading