Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/include/framework/tcp_connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class tcp_connection : public std::enable_shared_from_this<tcp_connection> {
co_return;
}

const std::uint32_t _length = static_cast<std::uint32_t>(_payload_size);
const auto _length = static_cast<std::uint32_t>(_payload_size);

std::array<unsigned char, 4> _header;
_header[0] = static_cast<unsigned char>(_length >> 24 & 0xFF);
Expand Down
99 changes: 55 additions & 44 deletions src/objects/framework/tcp_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,76 +20,87 @@
#include <framework/tcp_session.hpp>

namespace framework {

namespace {
async_of<std::tuple<boost::system::error_code, std::size_t>> read_exactly(boost::asio::ip::tcp::socket &socket,
boost::asio::streambuf &buffer, const std::size_t n) {
co_return co_await async_read(socket, buffer, boost::asio::transfer_exactly(n), boost::asio::as_tuple);
}

std::uint32_t read_uint32_from_buffer(boost::asio::streambuf &buffer) {
static_assert(HEADER_SIZE == 4, "HEADER_SIZE must be 4");
std::istream _input_stream(&buffer);
std::array<unsigned char, HEADER_SIZE> _header{};
_input_stream.read(reinterpret_cast<char *>(_header.data()), HEADER_SIZE);
return static_cast<std::uint32_t>(_header[0]) << 24 | static_cast<std::uint32_t>(_header[1]) << 16 |
static_cast<std::uint32_t>(_header[2]) << 8 | static_cast<std::uint32_t>(_header[3]) << 0;
}

async_of<void> notify_error_and_close(const shared_tcp_service service, const shared_tcp_connection connection,
boost::asio::ip::tcp::socket &socket, const std::exception error) {
if (service->handlers()->on_error()) co_await service->handlers()->on_error()(service, connection, error);
if (service->handlers()->on_disconnected()) co_await service->handlers()->on_disconnected()(service, connection);
boost::system::error_code _ec;
socket.shutdown(boost::asio::socket_base::shutdown_both, _ec);
socket.close(_ec);
co_return;
}

async_of<void> notify_disconnected_if_present(const shared_tcp_service service, const shared_tcp_connection connection) {
if (service->handlers()->on_disconnected()) co_await service->handlers()->on_disconnected()(service, connection);
co_return;
}
} // namespace

async_of<void> tcp_session(const shared_state state, const shared_tcp_service service, const shared_tcp_connection connection) {
auto _cancellation_state = co_await boost::asio::this_coro::cancellation_state;
boost::ignore_unused(state);

const auto _cancel_state = co_await boost::asio::this_coro::cancellation_state;
if (service->handlers()->on_accepted()) co_await service->handlers()->on_accepted()(service, connection);

auto &socket = connection->get_stream()->socket();
auto &buffer = connection->get_buffer();
auto &_socket = connection->get_stream()->socket();
auto &_buffer = connection->get_buffer();

while (!_cancellation_state.cancelled()) {
while (!_cancel_state.cancelled()) {
connection->get_stream()->expires_after(std::chrono::minutes(60));

auto [_header_ec, _header_read_bytes] =
co_await async_read(socket, buffer, boost::asio::transfer_exactly(HEADER_SIZE), boost::asio::as_tuple);

if (_header_ec) {
if (service->handlers()->on_disconnected()) co_await service->handlers()->on_disconnected()(service, connection);
if (auto [_read_header_ec, _header_read_bytes] = co_await read_exactly(_socket, _buffer, HEADER_SIZE); _read_header_ec) {
boost::ignore_unused(_header_read_bytes);
co_await notify_disconnected_if_present(service, connection);
co_return;
}

std::uint32_t _payload_length = 0;
std::istream _input_stream(&buffer);
unsigned char _header[HEADER_SIZE];
_input_stream.read(reinterpret_cast<char *>(_header), HEADER_SIZE);
_payload_length = (static_cast<std::uint32_t>(_header[0]) << 24) | (static_cast<std::uint32_t>(_header[1]) << 16) |
(static_cast<std::uint32_t>(_header[2]) << 8) | (static_cast<std::uint32_t>(_header[3]) << 0);

if (_payload_length == 0) {
continue;
}

if (_payload_length > MAX_FRAME_SIZE) {
const errors::tcp::frame_too_large _error;
if (service->handlers()->on_error()) co_await service->handlers()->on_error()(service, connection, _error);
if (service->handlers()->on_disconnected()) co_await service->handlers()->on_disconnected()(service, connection);
boost::system::error_code ignored_ec;
socket.shutdown(boost::asio::socket_base::shutdown_both, ignored_ec);
socket.close(ignored_ec);
const std::uint32_t _payload_size = read_uint32_from_buffer(_buffer);
if (_payload_size == 0) continue;
if (_payload_size > MAX_FRAME_SIZE) {
co_await notify_error_and_close(service, connection, _socket, errors::tcp::frame_too_large{});
co_return;
}

auto [_payload_ec, _bytes_transferred] =
co_await async_read(socket, buffer, boost::asio::transfer_exactly(_payload_length), boost::asio::as_tuple);
if (_payload_ec) {
const errors::tcp::on_read_error _error;
if (service->handlers()->on_error()) co_await service->handlers()->on_error()(service, connection, _error);
if (service->handlers()->on_disconnected()) co_await service->handlers()->on_disconnected()(service, connection);
if (auto [_read_payload_ec, _payload_read_bytes] = co_await read_exactly(_socket, _buffer, _payload_size); _read_payload_ec) {
boost::ignore_unused(_payload_read_bytes);
co_await notify_error_and_close(service, connection, _socket, errors::tcp::on_read_error{});
co_return;
}

if (static_cast<bool>(_cancellation_state.cancelled())) {
if (service->handlers()->on_disconnected()) co_await service->handlers()->on_disconnected()(service, connection);
if (static_cast<bool>(_cancel_state.cancelled())) {
co_await notify_disconnected_if_present(service, connection);
co_return;
}

std::string _payload;
_payload.resize(_payload_length);
_payload.resize(_payload_size);
{
std::istream is(&buffer);
is.read(_payload.data(), _payload_length);
std::istream _input_stream(&_buffer);
_input_stream.read(_payload.data(), _payload_size);
}

if (service->handlers()->on_read()) {
co_await service->handlers()->on_read()(service, connection, std::move(_payload));
}
if (service->handlers()->on_read()) co_await service->handlers()->on_read()(service, connection, std::move(_payload));
}

if (service->handlers()->on_disconnected()) co_await service->handlers()->on_disconnected()(service, connection);
co_await notify_disconnected_if_present(service, connection);

if (!connection->get_stream()->socket().is_open()) co_return;

connection->get_stream()->socket().shutdown(socket::shutdown_send);
_socket.shutdown(boost::asio::socket_base::shutdown_send);
}
} // namespace framework
Loading