Skip to content

Commit

Permalink
[ESI Runtime] Fix ordering problem with multiple reads
Browse files Browse the repository at this point in the history
ReadChannelPort wasn't accounting for multiple outstanding futures. Make
`read()` blocking, so as to simplify problem. Spawn 'readThread' to poll
for reads and fulfill promises.
  • Loading branch information
teqdruid committed Jun 15, 2024
1 parent 0a81a3e commit f1d632e
Show file tree
Hide file tree
Showing 9 changed files with 198 additions and 59 deletions.
39 changes: 16 additions & 23 deletions integration_test/Dialect/ESI/runtime/loopback.mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,20 @@

mysvc_recv = loopback.ports[esiaccel.AppID("mysvc_send")].read_port("send")
mysvc_recv.connect()
resp: bool = False
# Reads are non-blocking, so we need to poll.
while not resp:
print("i0 polling")
(resp, _) = mysvc_recv.read()
if not resp:
time.sleep(0.1)
print(f"i0 resp: {resp}")
mysvc_recv.read()
print("mysvc_recv.read() returned")

recv = loopback.ports[esiaccel.AppID("loopback_tohw")].write_port("recv")
recv.connect()
assert isinstance(recv.type, types.BitsType)

send = loopback.ports[esiaccel.AppID("loopback_fromhw")].read_port("send")
send.connect()
assert isinstance(send.type, types.BitsType)

data = 24
recv.write(int.to_bytes(data, 1, "little"))
resp = False
# Reads are non-blocking, so we need to poll.
resp_data: bytearray
while not resp:
print("polling")
(resp, resp_data) = send.read()
if not resp:
time.sleep(0.1)
resp_data: bytearray = send.read()
resp_int = int.from_bytes(resp_data, "little")

print(f"data: {data}")
Expand All @@ -76,20 +64,25 @@
if platform != "trace":
assert result == {"y": -22, "x": -21}

if platform != "trace":
print("Checking function call result ordering.")
future_result1 = myfunc(a=15, b=-32)
future_result2 = myfunc(a=32, b=47)
result2 = future_result2.result()
result1 = future_result1.result()
print(f"result1: {result1}")
print(f"result2: {result2}")
assert result1 == {"y": -32, "x": -31}, "result1 is incorrect"
assert result2 == {"y": 47, "x": 48}, "result2 is incorrect"

myfunc = d.ports[esiaccel.AppID("arrayFunc")]
arg_chan = myfunc.write_port("arg").connect()
result_chan = myfunc.read_port("result").connect()

arg = [-22]
arg_chan.write(arg)

result: Optional[List[int]] = None
resp = False
while not resp:
print("polling")
(resp, result) = result_chan.read()
if not resp:
time.sleep(0.1)
result: List[int] = result_chan.read()

print(f"result: {result}")
if platform != "trace":
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/ESI/runtime/cosim/include/cosim/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ namespace esi {
namespace cosim {

/// Thread safe queue. Just wraps std::queue protected with a lock.
// TODO: this is now in esiruntime proper. Remove this in the gRPC refactor
// wherein cosim will have a dependency on esiruntime.
template <typename T>
class TSQueue {
using Lock = std::lock_guard<std::mutex>;
Expand Down
62 changes: 52 additions & 10 deletions lib/Dialect/ESI/runtime/cpp/include/esi/Ports.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "esi/Common.h"
#include "esi/Types.h"
#include "esi/Utils.h"

#include <future>

Expand Down Expand Up @@ -54,18 +55,59 @@ class WriteChannelPort : public ChannelPort {
/// A ChannelPort which reads data from the accelerator.
class ReadChannelPort : public ChannelPort {
public:
using ChannelPort::ChannelPort;
ReadChannelPort(
const Type *type,
std::optional<std::chrono::microseconds> pollingInterval = std::nullopt)
: ChannelPort(type) {
if (pollingInterval)
this->pollingInterval = *pollingInterval;
else
this->pollingInterval = std::chrono::microseconds(100);
}

~ReadChannelPort() { disconnect(); }

/// In addition to the basic 'connect()' behavior, this also start a
/// 'read' thread.
virtual void connect() override;

/// In addition to the basic 'disconnect()' behavior, this also stops the
/// 'read' thread and 'join's on it.
virtual void disconnect() override;

/// Asynchronous read.
virtual std::future<MessageData> readAsync() {
auto p = new std::promise<MessageData>();
if (!promiseQueue.push(p)) {
throw std::runtime_error("Channel is disconnected");
delete p;
}
return p->get_future();
}

/// Specify a buffer to read into. Blocking. Basic API, will likely change
/// for performance and functionality reasons.
virtual void read(MessageData &outData) {
std::future<MessageData> f = readAsync();
f.wait();
outData = std::move(f.get());
}

protected:
std::chrono::microseconds pollingInterval;
std::thread readerThread;

// This queue is a producer-consumer. `readAsync` pushes a promise onto
// the queue. The 'readerThread' spins on the queue, pops one, waits for
// data, then fulfills the promise.
utils::TSQueue<std::promise<MessageData> *> promiseQueue;

/// Specify a buffer to read into. Non-blocking. Returns true if message
/// successfully recieved. Basic API, will likely change for performance
/// and functionality reasons.
virtual bool read(MessageData &) = 0;
/// Do the actual read.
virtual bool readInternal(MessageData &) = 0;

/// Asynchronous read. Returns a future which will be set when the message is
/// recieved. Could this subsume the synchronous read API?
/// The default implementation of this is really bad and should be overridden.
/// It simply polls `read` in a loop.
virtual std::future<MessageData> readAsync();
/// Reader thread main loop. Spins on a new message, then `pushAndPop`s a
/// new promise and fulfills the old one.
void readerThreadMain();
};

/// Services provide connections to 'bundles' -- collections of named,
Expand Down
55 changes: 55 additions & 0 deletions lib/Dialect/ESI/runtime/cpp/include/esi/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,67 @@
#define ESI_UTILS_H

#include <cstdint>
#include <mutex>
#include <optional>
#include <queue>
#include <string>

namespace esi {
namespace utils {
// Very basic base64 encoding.
void encodeBase64(const void *data, size_t size, std::string &out);

/// Thread safe queue. Just wraps std::queue protected with a lock. I think this
/// is horribly slow but works for now.
template <typename T>
class TSQueue {
using Lock = std::lock_guard<std::mutex>;

volatile bool stopped = false;
std::mutex m;
std::queue<T> q;

public:
/// Push onto the queue.
template <typename... E>
bool push(E... t) {
Lock l(m);
if (stopped)
return false;
q.emplace(t...);
return true;
}

/// Pop something off the queue but return nullopt if the queue is empty. Why
/// doesn't std::queue have anything like this?
std::optional<T> pop() {
Lock l(m);
if (q.size() == 0)
return std::nullopt;
auto t = q.front();
q.pop();
return t;
}

/// Signal that this channel should be shutdown.
void shutdown() {
Lock l(m);
stopped = true;
}

/// Signal that this channel should be restarted.
void restart() {
Lock l(m);
stopped = false;
}

/// Query if the channel is shutdown.
bool isShutdown() {
Lock l(m);
return stopped;
}
};

} // namespace utils
} // namespace esi

Expand Down
67 changes: 57 additions & 10 deletions lib/Dialect/ESI/runtime/cpp/lib/Ports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,61 @@ ReadChannelPort &BundlePort::getRawRead(const string &name) const {
return *read;
}

std::future<MessageData> ReadChannelPort::readAsync() {
// TODO: running this deferred is a horrible idea considering that it blocks!
// It's a hack since Capnp RPC refuses to work with multiple threads.
return std::async(std::launch::deferred, [this]() {
MessageData output;
while (!read(output)) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
return output;
});
void ReadChannelPort::connect() {
ChannelPort::connect();
readerThread = std::thread(&ReadChannelPort::readerThreadMain, this);
}

void ReadChannelPort::disconnect() {
if (promiseQueue.isShutdown())
return;

ChannelPort::disconnect();
promiseQueue.shutdown();
if (readerThread.joinable())
readerThread.join();

// Clear the promise queue, cancelling any pending reads.
std::optional<std::promise<MessageData> *> promise;
do {
promise = promiseQueue.pop();
if (!promise)
break;
(*promise)->set_exception(
std::make_exception_ptr(runtime_error("Channel disconnected")));
delete *promise;
} while (promise);
}

void ReadChannelPort::readerThreadMain() {
while (!promiseQueue.isShutdown()) {
// Get a promise to fulfill.
std::optional<std::promise<MessageData> *> promise;
do {
promise = promiseQueue.pop();
if (promise)
break;
std::this_thread::sleep_for(pollingInterval);
} while (!promiseQueue.isShutdown());

// Make sure we're not shutting down.
if (promiseQueue.isShutdown())
return;

// Try to read a message.
MessageData data;
do {
if (readInternal(data))
break;
std::this_thread::sleep_for(pollingInterval);
} while (!promiseQueue.isShutdown());

// Ensure we're not shutting down.
if (promiseQueue.isShutdown())
(*promise)->set_exception(
std::make_exception_ptr(runtime_error("Channel disconnected")));
else
(*promise)->set_value(std::move(data));
delete *promise;
}
}
6 changes: 4 additions & 2 deletions lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ class ReadCosimChannelPort : public ReadChannelPort {

// TODO: Replace this with a request to connect to the capnp thread.
virtual void connect() override {
ReadChannelPort::connect();
if (!ep)
throw runtime_error("Could not find channel '" + name +
"' in cosimulation");
Expand All @@ -201,10 +202,11 @@ class ReadCosimChannelPort : public ReadChannelPort {
ep->setInUse();
}
virtual void disconnect() override {
ReadChannelPort::disconnect();
if (ep)
ep->returnForUse();
}
virtual bool read(MessageData &) override;
virtual bool readInternal(MessageData &) override;

protected:
esi::cosim::Endpoint *ep;
Expand All @@ -213,7 +215,7 @@ class ReadCosimChannelPort : public ReadChannelPort {

} // namespace

bool ReadCosimChannelPort::read(MessageData &data) {
bool ReadCosimChannelPort::readInternal(MessageData &data) {
esi::cosim::Endpoint::MessageDataPtr msg;
if (!ep->getMessageToClient(msg))
return false;
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/ESI/runtime/cpp/lib/backends/Trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,14 @@ class ReadTraceChannelPort : public ReadChannelPort {
ReadTraceChannelPort(TraceAccelerator::Impl &impl, const Type *type)
: ReadChannelPort(type) {}

virtual bool read(MessageData &data) override;
virtual bool readInternal(MessageData &data) override;

private:
size_t numReads = 0;
};
} // namespace

bool ReadTraceChannelPort::read(MessageData &data) {
bool ReadTraceChannelPort::readInternal(MessageData &data) {
if ((++numReads & 0x1) == 1)
return false;

Expand Down
16 changes: 8 additions & 8 deletions lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,14 @@ PYBIND11_MODULE(esiCppAccel, m) {
p.write(dataVec);
});
py::class_<ReadChannelPort, ChannelPort>(m, "ReadChannelPort")
.def("read",
[](ReadChannelPort &p) -> py::object {
MessageData data;
if (!p.read(data))
return py::none();
return py::bytearray((const char *)data.getBytes(),
data.getSize());
})
.def(
"read",
[](ReadChannelPort &p) -> py::bytearray {
MessageData data;
p.read(data);
return py::bytearray((const char *)data.getBytes(), data.getSize());
},
"Read data from the channel. Blocking.")
.def("read_async", &ReadChannelPort::readAsync);

py::class_<BundlePort>(m, "BundlePort")
Expand Down
6 changes: 2 additions & 4 deletions lib/Dialect/ESI/runtime/python/esiaccel/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,17 +333,15 @@ def __init__(self, owner: BundlePort, cpp_port: cpp.ReadChannelPort):
super().__init__(owner, cpp_port)
self.cpp_port: cpp.ReadChannelPort = cpp_port

def read(self) -> Tuple[bool, Optional[object]]:
def read(self) -> object:
"""Read a typed message from the channel. Returns a deserialized object of a
type defined by the port type."""

buffer = self.cpp_port.read()
if buffer is None:
return (False, None)
(msg, leftover) = self.type.deserialize(buffer)
if len(leftover) != 0:
raise ValueError(f"leftover bytes: {leftover}")
return (True, msg)
return msg


class BundlePort:
Expand Down

0 comments on commit f1d632e

Please sign in to comment.