Skip to content

Commit

Permalink
[ESI Runtime] Read ports now invoke callbacks
Browse files Browse the repository at this point in the history
We've switched from a polling 'pull' method to a callback-based 'push'
mechanism for read ports. Polling (via std::futures) is built on top of
the push mechanism.

The read-ordering problem has also been fixed by using std::futures
exclusively for polling schemes. They also allow for poll-wait-notify
schemes without any chances on our part.
  • Loading branch information
teqdruid committed Jun 22, 2024
1 parent 90954e2 commit 284f792
Show file tree
Hide file tree
Showing 17 changed files with 321 additions and 142 deletions.
7 changes: 2 additions & 5 deletions frontends/PyCDE/integration_test/test_software/esi_ram.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Optional
from typing import cast
import esiaccel as esi
import random
import sys
Expand Down Expand Up @@ -31,10 +31,7 @@

def read(addr: int) -> bytearray:
mem_read_addr.write([addr])
got_data = False
resp: Optional[bytearray] = None
while not got_data:
(got_data, resp) = mem_read_data.read()
resp = cast(bytearray, mem_read_data.read())
print(f"resp: {resp}")
return resp

Expand Down
6 changes: 1 addition & 5 deletions frontends/PyCDE/integration_test/test_software/esi_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import esiaccel as esi

import sys
from typing import Optional

platform = sys.argv[1]
acc = esi.AcceleratorConnection(platform, sys.argv[2])
Expand All @@ -22,10 +21,7 @@
data = 10234
send.write(data)
got_data = False
resp: Optional[int] = None
# Reads are non-blocking, so we need to poll.
while not got_data:
(got_data, resp) = recv.read()
resp = recv.read()

print(f"data: {data}")
print(f"resp: {resp}")
Expand Down
39 changes: 16 additions & 23 deletions integration_test/Dialect/ESI/runtime/loopback.mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,20 @@

mysvc_recv = loopback.ports[esiaccel.AppID("mysvc_send")].read_port("send")
mysvc_recv.connect()
resp: bool = False
# Reads are non-blocking, so we need to poll.
while not resp:
print("i0 polling")
(resp, _) = mysvc_recv.read()
if not resp:
time.sleep(0.1)
print(f"i0 resp: {resp}")
mysvc_recv.read()
print("mysvc_recv.read() returned")

recv = loopback.ports[esiaccel.AppID("loopback_tohw")].write_port("recv")
recv.connect()
assert isinstance(recv.type, types.BitsType)

send = loopback.ports[esiaccel.AppID("loopback_fromhw")].read_port("send")
send.connect()
assert isinstance(send.type, types.BitsType)

data = 24
recv.write(int.to_bytes(data, 1, "little"))
resp = False
# Reads are non-blocking, so we need to poll.
resp_data: bytearray
while not resp:
print("polling")
(resp, resp_data) = send.read()
if not resp:
time.sleep(0.1)
resp_data: bytearray = send.read()
resp_int = int.from_bytes(resp_data, "little")

print(f"data: {data}")
Expand All @@ -76,20 +64,25 @@
if platform != "trace":
assert result == {"y": -22, "x": -21}

if platform != "trace":
print("Checking function call result ordering.")
future_result1 = myfunc(a=15, b=-32)
future_result2 = myfunc(a=32, b=47)
result2 = future_result2.result()
result1 = future_result1.result()
print(f"result1: {result1}")
print(f"result2: {result2}")
assert result1 == {"y": -32, "x": -31}, "result1 is incorrect"
assert result2 == {"y": 47, "x": 48}, "result2 is incorrect"

myfunc = d.ports[esiaccel.AppID("arrayFunc")]
arg_chan = myfunc.write_port("arg").connect()
result_chan = myfunc.read_port("result").connect()

arg = [-22]
arg_chan.write(arg)

result: Optional[List[int]] = None
resp = False
while not resp:
print("polling")
(resp, result) = result_chan.read()
if not resp:
time.sleep(0.1)
result: List[int] = result_chan.read()

print(f"result: {result}")
if platform != "trace":
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/ESI/runtime/cosim/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ install(FILES
COMPONENT ESIRuntime
)

add_library(EsiCosimGRPC OBJECT "${CMAKE_CURRENT_LIST_DIR}/cosim.proto")
add_library(EsiCosimGRPC SHARED "${CMAKE_CURRENT_LIST_DIR}/cosim.proto")
target_link_libraries(EsiCosimGRPC PUBLIC protobuf::libprotobuf gRPC::grpc++)
set(PROTO_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated")
target_include_directories(EsiCosimGRPC PUBLIC "$<BUILD_INTERFACE:${PROTO_BINARY_DIR}>")
Expand Down
18 changes: 12 additions & 6 deletions lib/Dialect/ESI/runtime/cosim/cosim_dpi_server/DpiEntryPoints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ static int validateSvOpenArray(const svOpenArrayHandle data,
// Lookups for registered ports. As a future optimization, change the DPI API to
// return a handle when registering wherein said handle is a pointer to a port.
std::map<std::string, ReadChannelPort &> readPorts;
std::map<ReadChannelPort *, std::future<MessageData>> readFutures;
std::map<std::string, WriteChannelPort &> writePorts;

// Register simulated device endpoints.
Expand All @@ -129,12 +130,15 @@ DPI int sv2cCosimserverEpRegister(char *endpointId, char *fromHostTypeIdC,
return -3;
}

if (!fromHostTypeId.empty())
readPorts.emplace(endpointId,
server->registerReadPort(endpointId, fromHostTypeId));
else
if (!fromHostTypeId.empty()) {
ReadChannelPort &port =
server->registerReadPort(endpointId, fromHostTypeId);
readPorts.emplace(endpointId, port);
readFutures.emplace(&port, port.readAsync());
} else {
writePorts.emplace(endpointId,
server->registerWritePort(endpointId, toHostTypeId));
}
return 0;
}

Expand All @@ -158,13 +162,15 @@ DPI int sv2cCosimserverEpTryGet(char *endpointId,
}

ReadChannelPort &port = portIt->second;
MessageData msg;
std::future<MessageData> &f = readFutures.at(&port);
// Poll for a message.
if (!port.read(msg)) {
if (f.wait_for(std::chrono::milliseconds(0)) != std::future_status::ready) {
// No message.
*dataSize = 0;
return 0;
}
MessageData msg = f.get();
f = port.readAsync();
log(endpointId, false, msg);

// Do the validation only if there's a message available. Since the
Expand Down
52 changes: 52 additions & 0 deletions lib/Dialect/ESI/runtime/cosim/include/cosim/Utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//===- Utils.h - utility code for cosim -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef COSIM_UTILS_H
#define COSIM_UTILS_H

#include <mutex>
#include <optional>
#include <queue>

namespace esi {
namespace cosim {

/// Thread safe queue. Just wraps std::queue protected with a lock.
// TODO: this is now in esiruntime proper. Remove this in the gRPC refactor
// wherein cosim will have a dependency on esiruntime.
template <typename T>
class TSQueue {
using Lock = std::lock_guard<std::mutex>;

std::mutex m;
std::queue<T> q;

public:
/// Push onto the queue.
template <typename... E>
void push(E... t) {
Lock l(m);
q.emplace(t...);
}

/// Pop something off the queue but return nullopt if the queue is empty. Why
/// doesn't std::queue have anything like this?
std::optional<T> pop() {
Lock l(m);
if (q.size() == 0)
return std::nullopt;
auto t = q.front();
q.pop();
return t;
}
};

} // namespace cosim
} // namespace esi

#endif
21 changes: 12 additions & 9 deletions lib/Dialect/ESI/runtime/cosim/lib/RpcServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,10 @@ class RpcServerReadPort : public ReadChannelPort {
public:
RpcServerReadPort(Type *type) : ReadChannelPort(type) {}

bool read(MessageData &data) override {
std::optional<MessageData> msg = readQueue.pop();
if (!msg)
return false;
data = std::move(*msg);
return true;
void push(MessageData &data) {
while (!callback(data))
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}

utils::TSQueue<MessageData> readQueue;
};

/// Implements a simple write queue. The RPC server will pull messages from this
Expand Down Expand Up @@ -154,6 +149,11 @@ Impl::Impl(int port) : esiVersion(-1) {
}

void Impl::stop() {
for (auto &[name, port] : readPorts)
port->disconnect();
for (auto &[name, port] : writePorts)
port->disconnect();

/// Shutdown the server and wait for it to finish.
server->Shutdown();
server->Wait();
Expand All @@ -169,12 +169,14 @@ ReadChannelPort &Impl::registerReadPort(const std::string &name,
const std::string &type) {
auto port = new RpcServerReadPort(new Type(type));
readPorts.emplace(name, port);
port->connect();
return *port;
}
WriteChannelPort &Impl::registerWritePort(const std::string &name,
const std::string &type) {
auto port = new RpcServerWritePort(new Type(type));
writePorts.emplace(name, port);
port->connect();
return *port;
}

Expand Down Expand Up @@ -301,6 +303,7 @@ void RpcServerWriteReactor::threadLoop() {
return ret;
});
}
Finish(Status::OK);
}

/// When a client sends a message to a read port (write port on this end), start
Expand Down Expand Up @@ -334,7 +337,7 @@ Impl::SendToServer(CallbackServerContext *context,
std::string msgDataString = request->message().data();
MessageData data(reinterpret_cast<const uint8_t *>(msgDataString.data()),
msgDataString.size());
it->second->readQueue.push(std::move(data));
it->second->push(data);
reactor->Finish(Status::OK);
return reactor;
}
Expand Down
87 changes: 74 additions & 13 deletions lib/Dialect/ESI/runtime/cpp/include/esi/Ports.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

#include "esi/Common.h"
#include "esi/Types.h"
#include "esi/Utils.h"

#include <cassert>
#include <future>

namespace esi {
Expand All @@ -31,15 +33,19 @@ namespace esi {
class ChannelPort {
public:
ChannelPort(const Type *type) : type(type) {}
virtual ~ChannelPort() = default;
virtual ~ChannelPort() { disconnect(); }

virtual void connect() {}
virtual void connect() { connectImpl(); }
virtual void disconnect() {}

const Type *getType() const { return type; }

private:
const Type *type;

/// Called by all connect methods to let backends initiate the underlying
/// connections.
virtual void connectImpl() {};
};

/// A ChannelPort which sends data to the accelerator.
Expand All @@ -51,21 +57,76 @@ class WriteChannelPort : public ChannelPort {
virtual void write(const MessageData &) = 0;
};

/// A ChannelPort which reads data from the accelerator.
/// A ChannelPort which reads data from the accelerator. It has two modes:
/// Callback and Polling which cannot be used at the same time. The mode is set
/// at connect() time. To change the mode, disconnect() and then connect()
/// again.
class ReadChannelPort : public ChannelPort {
enum Mode { Disconnected, Callback, Polling };

public:
using ChannelPort::ChannelPort;
ReadChannelPort(const Type *type) : ChannelPort(type) {}
virtual void disconnect() override { mode = Mode::Disconnected; }

//===--------------------------------------------------------------------===//
// Callback mode: To use a callback, connect with a callback function which
// will get called with incoming data. This function can be called from any
// thread. It shall return true to indicate that the data was consumed. False
// if it could not accept the data and should be tried again at some point in
// the future. Callback is not allowed to block and needs to execute quickly.
//
// TODO: Have the callback return something upon which the caller can check,
// wait, and be notified.
//===--------------------------------------------------------------------===//

virtual void connect(std::function<bool(MessageData)> callback);

//===--------------------------------------------------------------------===//
// Polling mode methods: To use futures or blocking reads, connect without any
// arguments. You will then be able to use readAsync() or read().
//===--------------------------------------------------------------------===//

/// Default max data queue size set at connect time.
static constexpr uint64_t DefaultMaxDataQueueMsgs = 32;

/// Connect to the channel in polling mode.
virtual void connect() override;

/// Asynchronous read.
virtual std::future<MessageData> readAsync();

/// Specify a buffer to read into. Non-blocking. Returns true if message
/// successfully recieved. Basic API, will likely change for performance
/// and functionality reasons.
virtual bool read(MessageData &) = 0;
/// Specify a buffer to read into. Blocking. Basic API, will likely change
/// for performance and functionality reasons.
virtual void read(MessageData &outData) {
std::future<MessageData> f = readAsync();
f.wait();
outData = std::move(f.get());
}

/// Asynchronous read. Returns a future which will be set when the message is
/// recieved. Could this subsume the synchronous read API?
/// The default implementation of this is really bad and should be overridden.
/// It simply polls `read` in a loop.
virtual std::future<MessageData> readAsync();
/// Set maximum number of messages to store in the dataQueue. 0 means no
/// limit. This is only used in polling mode and is set to default of 32 upon
/// connect.
void setMaxDataQueueMsgs(uint64_t maxMsgs) { maxDataQueueMsgs = maxMsgs; }

protected:
Mode mode;

/// Backends call this callback when new data is available.
std::function<bool(MessageData)> callback;

//===--------------------------------------------------------------------===//
// Polling mode members.
//===--------------------------------------------------------------------===//

/// Mutex to protect the two queues used for polling.
std::mutex pollingM;
/// Store incoming data here if there are no outstanding promises to be
/// fulfilled.
std::queue<MessageData> dataQueue;
/// Maximum number of messages to store in dataQueue. 0 means no limit.
uint64_t maxDataQueueMsgs;
/// Promises to be fulfilled when data is available.
std::queue<std::promise<MessageData>> promiseQueue;
};

/// Services provide connections to 'bundles' -- collections of named,
Expand Down
Loading

0 comments on commit 284f792

Please sign in to comment.