Skip to content

Commit

Permalink
[ESI][Runtime] Poll method and optional service thread polling (#7460)
Browse files Browse the repository at this point in the history
Add a poll method to ports, a master poll method to the Accelerator, and the ability to poll from the service thread. Also, only spin up the service thread if it's requested.

The service thread polling (in particular) required some ownership changes: Accelerator objects now belong to the AcceleratorConnection so that the ports aren't destructed before the service thread gets shutdown (which causes an invalid memory access). This particular binding isn't ideal, is brittle, and will be an issue for anything doing the polling. Resolving #7457 should mitigate this issue.

Backends are now _required_ to call `disconnect` in their destructor.
  • Loading branch information
teqdruid authored Aug 8, 2024
1 parent caab217 commit ad91378
Show file tree
Hide file tree
Showing 17 changed files with 164 additions and 68 deletions.
3 changes: 3 additions & 0 deletions integration_test/Dialect/ESI/runtime/loopback.mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,7 @@
print(f"result: {result}")
if platform != "trace":
assert result == [-21, -22]

acc = None

print("PASS")
22 changes: 19 additions & 3 deletions lib/Dialect/ESI/runtime/cpp/include/esi/Accelerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ class Accelerator : public HWModule {

/// Abstract class representing a connection to an accelerator. Actual
/// connections (e.g. to a co-simulation or actual device) are implemented by
/// subclasses.
/// subclasses. No methods in here are thread safe.
class AcceleratorConnection {
public:
AcceleratorConnection(Context &ctxt);
virtual ~AcceleratorConnection() = default;
virtual ~AcceleratorConnection();
Context &getCtxt() const { return ctxt; }

/// Disconnect from the accelerator cleanly.
Expand All @@ -89,7 +89,12 @@ class AcceleratorConnection {
virtual std::map<std::string, ChannelPort &>
requestChannelsFor(AppIDPath, const BundleType *) = 0;

AcceleratorServiceThread *getServiceThread() { return serviceThread.get(); }
/// Return a pointer to the accelerator 'service' thread (or threads). If the
/// thread(s) are not running, they will be started when this method is
/// called. `std::thread` is used. If users don't want the runtime to spin up
/// threads, don't call this method. `AcceleratorServiceThread` is owned by
/// AcceleratorConnection and governed by the lifetime of the this object.
AcceleratorServiceThread *getServiceThread();

using Service = services::Service;
/// Get a typed reference to a particular service type. Caller does *not* take
Expand All @@ -109,6 +114,10 @@ class AcceleratorConnection {
ServiceImplDetails details = {},
HWClientDetails clients = {});

/// Assume ownership of an accelerator object. Ties the lifetime of the
/// accelerator to this connection. Returns a raw pointer to the object.
Accelerator *takeOwnership(std::unique_ptr<Accelerator> accel);

protected:
/// Called by `getServiceImpl` exclusively. It wraps the pointer returned by
/// this in a unique_ptr and caches it. Separate this from the
Expand All @@ -128,6 +137,10 @@ class AcceleratorConnection {
std::map<ServiceCacheKey, std::unique_ptr<Service>> serviceCache;

std::unique_ptr<AcceleratorServiceThread> serviceThread;

/// List of accelerator objects owned by this connection. These are destroyed
/// when the connection dies or is shutdown.
std::vector<std::unique_ptr<Accelerator>> ownedAccelerators;
};

namespace registry {
Expand Down Expand Up @@ -173,6 +186,9 @@ class AcceleratorServiceThread {
addListener(std::initializer_list<ReadChannelPort *> listenPorts,
std::function<void(ReadChannelPort *, MessageData)> callback);

/// Poll this module.
void addPoll(HWModule &module);

/// Instruct the service thread to stop running.
void stop();

Expand Down
5 changes: 5 additions & 0 deletions lib/Dialect/ESI/runtime/cpp/include/esi/Design.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ class HWModule {
return portIndex;
}

/// Master poll method. Calls the `poll` method on all locally owned ports and
/// the master `poll` method on all of the children. Returns true if any of
/// the `poll` calls returns true.
bool poll();

protected:
const std::optional<ModuleInfo> info;
const std::vector<std::unique_ptr<Instance>> children;
Expand Down
7 changes: 4 additions & 3 deletions lib/Dialect/ESI/runtime/cpp/include/esi/Manifest.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ class Manifest {
// Modules which have designer specified metadata.
std::vector<ModuleInfo> getModuleInfos() const;

// Build a dynamic design hierarchy from the manifest.
std::unique_ptr<Accelerator>
buildAccelerator(AcceleratorConnection &acc) const;
// Build a dynamic design hierarchy from the manifest. The
// AcceleratorConnection owns the returned pointer so its lifetime is
// determined by the connection.
Accelerator *buildAccelerator(AcceleratorConnection &acc) const;

/// The Type Table is an ordered list of types. The offset can be used to
/// compactly and uniquely within a design. It does not include all of the
Expand Down
52 changes: 46 additions & 6 deletions lib/Dialect/ESI/runtime/cpp/include/esi/Ports.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,38 @@ namespace esi {
class ChannelPort {
public:
ChannelPort(const Type *type) : type(type) {}
virtual ~ChannelPort() { disconnect(); }
virtual ~ChannelPort() {}

/// Set up a connection to the accelerator. The buffer size is optional and
/// should be considered merely a hint. Individual implementations use it
/// however they like. The unit is number of messages of the port type.
virtual void connect(std::optional<unsigned> bufferSize = std::nullopt) {
connectImpl(bufferSize);
virtual void connect(std::optional<unsigned> bufferSize = std::nullopt) = 0;
virtual void disconnect() = 0;
virtual bool isConnected() const = 0;

/// Poll for incoming data. Returns true if data was read or written into a
/// buffer as a result of the poll. Calling the call back could (will) also
/// happen in that case. Some backends need this to be called periodically. In
/// the usual case, this will be called by a background thread, but the ESI
/// runtime does not want to assume that the host processes use standard
/// threads. If the user wants to provide their own threads, they need to call
/// this on each port occasionally. This is also called from the 'master' poll
/// method in the Accelerator class.
bool poll() {
if (isConnected())
return pollImpl();
return false;
}
virtual void disconnect() {}

const Type *getType() const { return type; }

private:
protected:
const Type *type;

/// Method called by poll() to actually poll the channel if the channel is
/// connected.
virtual bool pollImpl() { return false; }

/// Called by all connect methods to let backends initiate the underlying
/// connections.
virtual void connectImpl(std::optional<unsigned> bufferSize) {}
Expand All @@ -58,8 +75,19 @@ class WriteChannelPort : public ChannelPort {
public:
using ChannelPort::ChannelPort;

virtual void
connect(std::optional<unsigned> bufferSize = std::nullopt) override {
connectImpl(bufferSize);
connected = true;
}
virtual void disconnect() override { connected = false; }
virtual bool isConnected() const override { return connected; }

/// A very basic write API. Will likely change for performance reasons.
virtual void write(const MessageData &) = 0;

private:
volatile bool connected = false;
};

/// A ChannelPort which reads data from the accelerator. It has two modes:
Expand All @@ -72,6 +100,9 @@ class ReadChannelPort : public ChannelPort {
ReadChannelPort(const Type *type)
: ChannelPort(type), mode(Mode::Disconnected) {}
virtual void disconnect() override { mode = Mode::Disconnected; }
virtual bool isConnected() const override {
return mode != Mode::Disconnected;
}

//===--------------------------------------------------------------------===//
// Callback mode: To use a callback, connect with a callback function which
Expand Down Expand Up @@ -121,7 +152,7 @@ class ReadChannelPort : public ChannelPort {
protected:
/// Indicates the current mode of the channel.
enum Mode { Disconnected, Callback, Polling };
Mode mode;
volatile Mode mode;

/// Backends call this callback when new data is available.
std::function<bool(MessageData)> callback;
Expand Down Expand Up @@ -178,6 +209,15 @@ class BundlePort {
return const_cast<T *>(dynamic_cast<const T *>(this));
}

/// Calls `poll` on all channels in the bundle and returns true if any of them
/// returned true.
bool poll() {
bool result = false;
for (auto &channel : channels)
result |= channel.second.poll();
return result;
}

private:
AppID id;
std::map<std::string, ChannelPort &> channels;
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/ESI/runtime/cpp/include/esi/backends/Trace.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class TraceAccelerator : public esi::AcceleratorConnection {
/// is opened for writing. For 'Read' mode, this file is opened for reading.
TraceAccelerator(Context &, Mode mode, std::filesystem::path manifestJson,
std::filesystem::path traceFile);
~TraceAccelerator() override;

/// Parse the connection string and instantiate the accelerator. Format is:
/// "<mode>:<manifest path>[:<traceFile>]".
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/ESI/runtime/cpp/include/esi/backends/Xrt.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class XrtAccelerator : public esi::AcceleratorConnection {
struct Impl;

XrtAccelerator(Context &, std::string xclbin, std::string kernelName);
~XrtAccelerator();
static std::unique_ptr<AcceleratorConnection>
connect(Context &, std::string connectionString);

Expand Down
48 changes: 43 additions & 5 deletions lib/Dialect/ESI/runtime/cpp/lib/Accelerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ using namespace esi::services;

namespace esi {
AcceleratorConnection::AcceleratorConnection(Context &ctxt)
: ctxt(ctxt), serviceThread(std::make_unique<AcceleratorServiceThread>()) {}
: ctxt(ctxt), serviceThread(nullptr) {}
AcceleratorConnection::~AcceleratorConnection() { disconnect(); }

AcceleratorServiceThread *AcceleratorConnection::getServiceThread() {
if (!serviceThread)
serviceThread = std::make_unique<AcceleratorServiceThread>();
return serviceThread.get();
}

services::Service *AcceleratorConnection::getService(Service::Type svcType,
AppIDPath id,
Expand All @@ -54,6 +61,13 @@ services::Service *AcceleratorConnection::getService(Service::Type svcType,
return cacheEntry.get();
}

Accelerator *
AcceleratorConnection::takeOwnership(std::unique_ptr<Accelerator> acc) {
Accelerator *ret = acc.get();
ownedAccelerators.push_back(std::move(acc));
return ret;
}

/// Get the path to the currently running executable.
static std::filesystem::path getExePath() {
#ifdef __linux__
Expand Down Expand Up @@ -224,18 +238,27 @@ struct AcceleratorServiceThread::Impl {
addListener(std::initializer_list<ReadChannelPort *> listenPorts,
std::function<void(ReadChannelPort *, MessageData)> callback);

void addTask(std::function<void(void)> task) {
std::lock_guard<std::mutex> g(m);
taskList.push_back(task);
}

private:
void loop();
volatile bool shutdown = false;
std::thread me;

// Protect the listeners std::map.
std::mutex listenerMutex;
// Protect the shared data structures.
std::mutex m;

// Map of read ports to callbacks.
std::map<ReadChannelPort *,
std::pair<std::function<void(ReadChannelPort *, MessageData)>,
std::future<MessageData>>>
listeners;

/// Tasks which should be called on every loop iteration.
std::vector<std::function<void(void)>> taskList;
};

void AcceleratorServiceThread::Impl::loop() {
Expand All @@ -245,6 +268,7 @@ void AcceleratorServiceThread::Impl::loop() {
std::function<void(ReadChannelPort *, MessageData)>,
MessageData>>
portUnlockWorkList;
std::vector<std::function<void(void)>> taskListCopy;
MessageData data;

while (!shutdown) {
Expand All @@ -256,7 +280,7 @@ void AcceleratorServiceThread::Impl::loop() {
// Check and gather data from all the read ports we are monitoring. Put the
// callbacks to be called later so we can release the lock.
{
std::lock_guard<std::mutex> g(listenerMutex);
std::lock_guard<std::mutex> g(m);
for (auto &[channel, cbfPair] : listeners) {
assert(channel && "Null channel in listener list");
std::future<MessageData> &f = cbfPair.second;
Expand All @@ -273,13 +297,22 @@ void AcceleratorServiceThread::Impl::loop() {

// Clear the worklist for the next iteration.
portUnlockWorkList.clear();

// Call any tasks that have been added. Copy it first so we can release the
// lock ASAP.
{
std::lock_guard<std::mutex> g(m);
taskListCopy = taskList;
}
for (auto &task : taskListCopy)
task();
}
}

void AcceleratorServiceThread::Impl::addListener(
std::initializer_list<ReadChannelPort *> listenPorts,
std::function<void(ReadChannelPort *, MessageData)> callback) {
std::lock_guard<std::mutex> g(listenerMutex);
std::lock_guard<std::mutex> g(m);
for (auto port : listenPorts) {
if (listeners.count(port))
throw std::runtime_error("Port already has a listener");
Expand Down Expand Up @@ -312,6 +345,11 @@ void AcceleratorServiceThread::addListener(
impl->addListener(listenPorts, callback);
}

void AcceleratorServiceThread::addPoll(HWModule &module) {
assert(impl && "Service thread not running");
impl->addTask([&module]() { module.poll(); });
}

void AcceleratorConnection::disconnect() {
if (serviceThread) {
serviceThread->stop();
Expand Down
9 changes: 9 additions & 0 deletions lib/Dialect/ESI/runtime/cpp/lib/Design.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,13 @@ HWModule::HWModule(std::optional<ModuleInfo> info,
childIndex(buildIndex(this->children)), services(services),
ports(std::move(ports)), portIndex(buildIndex(this->ports)) {}

bool HWModule::poll() {
bool result = false;
for (auto &port : ports)
result |= port->poll();
for (auto &child : children)
result |= child->poll();
return result;
}

} // namespace esi
5 changes: 2 additions & 3 deletions lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,9 +534,8 @@ std::vector<ModuleInfo> Manifest::getModuleInfos() const {
return ret;
}

std::unique_ptr<Accelerator>
Manifest::buildAccelerator(AcceleratorConnection &acc) const {
return impl->buildAccelerator(acc);
Accelerator *Manifest::buildAccelerator(AcceleratorConnection &acc) const {
return acc.takeOwnership(impl->buildAccelerator(acc));
}

const std::vector<const Type *> &Manifest::getTypeTable() const {
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/ESI/runtime/cpp/lib/Ports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void ReadChannelPort::connect(std::function<bool(MessageData)> callback,
throw std::runtime_error("Channel already connected");
mode = Mode::Callback;
this->callback = callback;
ChannelPort::connect(bufferSize);
connectImpl(bufferSize);
}

void ReadChannelPort::connect(std::optional<unsigned> bufferSize) {
Expand All @@ -71,7 +71,7 @@ void ReadChannelPort::connect(std::optional<unsigned> bufferSize) {
}
return true;
};
ChannelPort::connect(bufferSize);
connectImpl(bufferSize);
}

std::future<MessageData> ReadChannelPort::readAsync() {
Expand Down
Loading

0 comments on commit ad91378

Please sign in to comment.