Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/session.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class PGSession : public std::enable_shared_from_this<PGSession>
handle_ssl_negotiation();
}

static boost::asio::io_context &get_io_context();

private:
void parse_startup_params(const char *data, size_t length);
void handle_ssl_negotiation();
Expand Down
5 changes: 2 additions & 3 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,8 @@ int main(int argc, char *argv[])
}
PINFO << "Start on port " << port;

boost::asio::io_context io_context;
Server server(io_context, port);
io_context.run();
Server server(PGSession::get_io_context(), port);
PGSession::get_io_context().run();
}
catch (std::exception &e)
{
Expand Down
55 changes: 44 additions & 11 deletions src/session.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
#include <boost/asio/thread_pool.hpp>
#include <boost/asio/post.hpp>
#include <boost/thread.hpp>
#include <boost/asio.hpp>
#include <boost/bind/bind.hpp>

#include "session.hpp"
#include <memory>
#include <set>
#include "log.hpp"
#include "db.hpp"

#include <memory>
#include <set>

using boost::asio::ip::tcp;
static boost::asio::thread_pool thread_pool_(4); // 4 threads in the pool

std::unordered_map<std::string, uint32_t> duckdb_to_pg_type = {
{"BOOLEAN", 16}, // PG: bool
Expand All @@ -23,6 +32,13 @@ std::unordered_map<std::string, uint32_t> duckdb_to_pg_type = {
// 添加更多类型映射...
};

boost::asio::io_context &
PGSession::get_io_context()
{
static boost::asio::io_context io_context_;
return io_context_;
}

// 解析启动参数
void PGSession::parse_startup_params(const char *data, size_t length)
{
Expand Down Expand Up @@ -189,7 +205,13 @@ void PGSession::handle_simple_query()
if (!ec)
{
// 处理查询并返回结果
self->process_query();
boost::asio::post(thread_pool_,
[self]()
{
self->process_query();
});
// 继续处理下一个查询
self->handle_query();
}
});
}
Expand Down Expand Up @@ -366,9 +388,6 @@ void PGSession::process_query()
}
// 3. 发送ReadyForQuery
send_ready_for_query();

// 继续处理下一个查询
handle_query();
}

// 发送行描述
Expand Down Expand Up @@ -439,7 +458,9 @@ void PGSession::send_row_description(const std::vector<ColumnDesc> &columns)
std::copy(reinterpret_cast<uint8_t *>(&net_len),
reinterpret_cast<uint8_t *>(&net_len) + 4,
msg.begin() + len_pos);
asio::write(socket_, asio::buffer(msg));

boost::asio::post(get_io_context(), [self = shared_from_this(), msg]()
{ asio::write(self->socket_, asio::buffer(msg)); });
}

// 发送数据行
Expand Down Expand Up @@ -475,7 +496,11 @@ void PGSession::send_data_row(const std::vector<std::string> &values)
msg.insert(msg.end(), val.begin(), val.end());
}

asio::write(socket_, asio::buffer(msg));
boost::asio::post(get_io_context(),
[self = shared_from_this(), msg]()
{
asio::write(self->socket_, asio::buffer(msg));
});
}

// 发送命令完成
Expand All @@ -493,12 +518,20 @@ void PGSession::send_command_complete(const std::string &tag)
msg.insert(msg.end(), tag.begin(), tag.end());
msg.push_back('\0');

asio::write(socket_, asio::buffer(msg));
boost::asio::post(get_io_context(),
[self = shared_from_this(), msg]()
{
asio::write(self->socket_, asio::buffer(msg));
});
}

// 发送ReadyForQuery
void PGSession::send_ready_for_query()
{
std::vector<char> ready = {'Z', 0, 0, 0, 5, 'I'};
asio::write(socket_, asio::buffer(ready));
boost::asio::post(get_io_context(),
[self = shared_from_this()]()
{
std::vector<char> ready = {'Z', 0, 0, 0, 5, 'I'};
asio::write(self->socket_, asio::buffer(ready));
});
}