Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop saving shares as strings #267

Merged
merged 12 commits into from
Jul 27, 2023
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class FixedPoint : private boost::operators<FixedPoint>
tmp *= shift;
value = static_cast<mp_int>(tmp);
}

FixedPoint(const std::string &str)
{
mp_float v_{str};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

namespace qmpc::ComputationToComputation
{

Server::Server() noexcept
{
Config *conf = Config::getInstance();
Expand All @@ -20,20 +21,6 @@ Server::Server() noexcept
}
}

static std::string share_to_str(const computationtocomputation::Shares_Share &share)
{
using cs = computationtocomputation::Shares_Share;
switch (share.value_case())
{
case (cs::ValueCase::kFlag):
return std::to_string(share.flag());
case (cs::ValueCase::kNum):
return std::to_string(share.num());
case (cs::ValueCase::kByte):
case (cs::ValueCase::VALUE_NOT_SET):
return share.byte();
}
}
// 複数シェアをexchangeする場合
grpc::Status Server::ExchangeShares(
grpc::ServerContext *context,
Expand All @@ -44,7 +31,7 @@ grpc::Status Server::ExchangeShares(
computationtocomputation::Shares multiple_shares;
bool first = true;
int party_id, share_id, job_id, thread_id;
std::vector<std::string> share_str_vec;
std::vector<CtoCShare> share_vec;

while (stream->Read(&multiple_shares))
{
Expand All @@ -63,80 +50,21 @@ grpc::Status Server::ExchangeShares(
for (int i = 0; i < multiple_shares.share_list_size(); i++)
{
auto share = multiple_shares.share_list(i);

auto share_str = share_to_str(share);
share_str_vec.emplace_back(share_str);
share_vec.emplace_back(share);
}
first = false;
}

std::lock_guard<std::mutex> lock(mtx); // mutex発動
if (!first)
{
shares_vec[std::make_tuple(party_id, share_id, job_id, thread_id)] = share_str_vec;
shares_vec[std::make_tuple(party_id, share_id, job_id, thread_id)] = share_vec;
}

cond.notify_all(); // 通知
return grpc::Status::OK;
}

// 単一シェアget用
std::string Server::getShare(int party_id, qmpc::Share::AddressId share_id)
{
Config *conf = Config::getInstance();
std::unique_lock<std::mutex> lock(mtx); // mutex発動
auto key = std::make_tuple(
party_id, share_id.getShareId(), share_id.getJobId(), share_id.getThreadId()
);
if (!cond.wait_for(
lock,
std::chrono::seconds(conf->getshare_time_limit),
[&] { return shares_vec.count(key) == 1; }
)) // 待機
{
qmpc::Log::throw_with_trace(std::runtime_error("getShare is timeout"));
}
auto share = shares_vec[key];
shares_vec.erase(key);
return share[0];
}

// 複数シェアget用
std::vector<std::string> Server::getShares(
int party_id, const std::vector<qmpc::Share::AddressId> &share_ids
)
{
const std::size_t length = share_ids.size();
if (length == 0)
{
return std::vector<std::string>{};
}

Config *conf = Config::getInstance();
// std::cout << "party share job thread"
// << " " << party_id << " " << share_ids[0].getShareId() << " "
// << share_ids[0].getJobId() << " " << share_ids[0].getThreadId() << std::endl;
std::vector<std::string> str_values;
str_values.reserve(length);
auto key = std::make_tuple(
party_id, share_ids[0].getShareId(), share_ids[0].getJobId(), share_ids[0].getThreadId()
);
std::unique_lock<std::mutex> lock(mtx); // mutex発動

if (!cond.wait_for(
lock,
std::chrono::seconds(conf->getshare_time_limit * length),
[&] { return shares_vec.count(key) == 1; }
)) // 待機
{
qmpc::Log::throw_with_trace(std::runtime_error("getShares is timeout"));
}
auto local_str_shares = shares_vec[key];
shares_vec.erase(key);
assert(local_str_shares.size() == length);
return local_str_shares;
}

void Server::runServer(std::string endpoint)
{
auto server = Server::getServer();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "unistd.h"
namespace qmpc::ComputationToComputation
{
using CtoCShare = computationtocomputation::Shares_Share;

class Server final : public computationtocomputation::ComputationToComputation::Service
{
std::map<int, std::shared_ptr<Client>> ccClients;
Expand All @@ -39,10 +41,62 @@ class Server final : public computationtocomputation::ComputationToComputation::
google::protobuf::Empty *response
) override;
// 受け取ったシェアをgetするメソッド
std::string getShare(int party_id, qmpc::Share::AddressId share_id);
std::vector<std::string> getShares(
int party_id, const std::vector<qmpc::Share::AddressId> &share_ids
);
template <typename SV>
SV getShare(int party_id, qmpc::Share::AddressId share_id)
{
Config *conf = Config::getInstance();
std::unique_lock<std::mutex> lock(mtx); // mutex発動
auto key = std::make_tuple(
party_id, share_id.getShareId(), share_id.getJobId(), share_id.getThreadId()
);
if (!cond.wait_for(
lock,
std::chrono::seconds(conf->getshare_time_limit),
[&] { return shares_vec.count(key) == 1; }
)) // 待機
{
qmpc::Log::throw_with_trace(std::runtime_error("getShare is timeout"));
}
auto share = shares_vec[key][0];
shares_vec.erase(key);
return toSV<SV>(share);
}
template <typename SV>
std::vector<SV> getShares(int party_id, const std::vector<qmpc::Share::AddressId> &share_ids)
{
const std::size_t length = share_ids.size();
if (length == 0)
{
return std::vector<SV>{};
}

Config *conf = Config::getInstance();
// std::cout << "party share job thread"
// << " " << party_id << " " << share_ids[0].getShareId() << " "
// << share_ids[0].getJobId() << " " << share_ids[0].getThreadId() << std::endl;
auto key = std::make_tuple(
party_id, share_ids[0].getShareId(), share_ids[0].getJobId(), share_ids[0].getThreadId()
);
std::unique_lock<std::mutex> lock(mtx); // mutex発動

if (!cond.wait_for(
lock,
std::chrono::seconds(conf->getshare_time_limit * length),
[&] { return shares_vec.count(key) == 1; }
)) // 待機
{
qmpc::Log::throw_with_trace(std::runtime_error("getShares is timeout"));
}
auto local_str_shares = shares_vec[key];
shares_vec.erase(key);
assert(local_str_shares.size() == length);
std::vector<SV> ret(length);
for (size_t i = 0; i < length; i++)
{
ret[i] = toSV<SV>(local_str_shares[i]);
}
return ret;
}
Server(Server &&) noexcept = delete;
Server(const Server &) noexcept = delete;
Server &operator=(Server &&) noexcept = delete;
Expand Down Expand Up @@ -70,7 +124,27 @@ class Server final : public computationtocomputation::ComputationToComputation::
using address_type = std::tuple<int, int, unsigned int, int>;
// 受け取ったシェアを保存する変数
// party_id, share_idをキーとして保存
std::map<address_type, std::vector<std::string>> shares_vec;
std::map<address_type, std::vector<CtoCShare>> shares_vec;

template <typename SV>
SV toSV(const CtoCShare &share_value)
{
if constexpr (std::is_same_v<SV, bool>)
{
assert(share_value.has_flag());
return share_value.flag();
}
else if constexpr (std::is_integral_v<SV>)
{
assert(share_value.has_num());
return share_value.num();
}
else
{
assert(share_value.has_byte());
return SV(share_value.byte());
}
}
};

} // namespace qmpc::ComputationToComputation
25 changes: 4 additions & 21 deletions packages/server/computation_container/share/networking.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,7 @@ void open(const T &share)
}
}

template <typename SV>
SV stosv(const std::string &str_value)
{
if constexpr (std::is_same_v<SV, bool>)
{
assert(str_value == "0" || str_value == "1");
return str_value == "1";
}
else if constexpr (std::is_integral_v<SV>)
{
return std::stoll(str_value);
}
else // TODO: constructable or convertible
{
return SV(str_value);
}
}
using CtoCShare = computationtocomputation::Shares_Share;

template <typename SV>
auto recons(const Share<SV> &share)
Expand All @@ -129,8 +113,7 @@ auto recons(const Share<SV> &share)
}
else
{
std::string s = server->getShare(pt_id, share.getId());
ret += stosv<SV>(s);
ret += server->getShare<SV>(pt_id, share.getId());
}
}
return ret;
Expand Down Expand Up @@ -162,10 +145,10 @@ auto recons(const std::vector<Share<SV>> &share)
}
else
{
std::vector<std::string> values = server->getShares(pt_id, ids_list);
std::vector<SV> values = server->getShares<SV>(pt_id, ids_list);
for (unsigned int i = 0; i < length; i++)
{
ret[i] += stosv<SV>(values[i]);
ret[i] += values[i];
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion packages/server/computation_container/test/unit_test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ cc_test(
"//config_parse:config_parse",
"@proto//computation_to_computation_container:computation_to_computation_cc_grpc",
"//share:address",
"//logging:log"
"//logging:log",
"//fixed_point:fixed_point"
],
linkopts = ["-pthread"],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "config_parse/config_parse.hpp"
#include "external/proto/computation_to_computation_container/computation_to_computation.grpc.pb.h"
#include "fixed_point/fixed_point.hpp"
#include "gtest/gtest.h"
#include "logging/logger.hpp"
#include "server/computation_to_computation_container/server.hpp"
Expand Down Expand Up @@ -63,7 +64,7 @@ TEST(CtoC_Test, EXCHANGESHARE)
EXPECT_TRUE(status.ok());

auto server = qmpc::ComputationToComputation::Server::getServer();
auto data = server->getShare(conf->party_id, share_id[0]);
std::string data = server->getShare<std::string>(conf->party_id, share_id[0]);
EXPECT_EQ(value, data);
}
TEST(CtoC_Test, EXCHANGESHARES)
Expand Down Expand Up @@ -96,7 +97,7 @@ TEST(CtoC_Test, EXCHANGESHARES)
EXPECT_TRUE(status.ok());

auto server = qmpc::ComputationToComputation::Server::getServer();
auto datas = server->getShares(conf->party_id, share_ids);
std::vector<std::string> datas = server->getShares<std::string>(conf->party_id, share_ids);
for (unsigned int i = 0; i < length; i++)
{
EXPECT_EQ(values[i], datas[i]);
Expand All @@ -109,7 +110,7 @@ TEST(CtoC_Test, GetShareThrowExceptionTest)
qmpc::Share::AddressId share_id;

auto server = qmpc::ComputationToComputation::Server::getServer();
EXPECT_ANY_THROW(server->getShare(conf->party_id, share_id));
EXPECT_ANY_THROW(server->getShare<std::string>(conf->party_id, share_id));
}

TEST(CtoC_Test, GetSharesThrowExceptionTest)
Expand All @@ -119,7 +120,7 @@ TEST(CtoC_Test, GetSharesThrowExceptionTest)
std::vector<qmpc::Share::AddressId> share_ids(length);

auto server = qmpc::ComputationToComputation::Server::getServer();
EXPECT_ANY_THROW(server->getShares(conf->party_id, share_ids));
EXPECT_ANY_THROW(server->getShares<std::string>(conf->party_id, share_ids));
}
int main(int argc, char **argv)
{
Expand Down
Loading