Skip to content

Commit

Permalink
Merge pull request #127 from acompany-develop/feature/ichikawa/proto
Browse files Browse the repository at this point in the history
Update Mono-Address Networking
  • Loading branch information
arukuka authored Mar 9, 2023
2 parents a861246 + fb277de commit e719675
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 149 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ bazel-testlogs
bazel-tools
.bazel
result

proto/bazel-proto
**/vendor

**/db/**/data/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,44 +27,6 @@ class Client
Client(const Url &endpoint) noexcept;
Client &operator=(Client &&) noexcept = default;
static std::shared_ptr<Client> getPtr(const Url &);
// 単一シェアを生成する場合
template <typename T>
computationtocomputation::Share makeShare(
T &&value, qmpc::Share::AddressId share_id, int party_id
) const
{
computationtocomputation::Share s;
auto a = s.mutable_address_id();
a->set_share_id(share_id.getShareId());
a->set_job_id(share_id.getJobId());

if constexpr (std::is_same_v<std::decay_t<T>, bool>)
{
s.set_flag(value);
}
else if constexpr (std::is_same_v<std::decay_t<T>, int>)
{
s.set_num(value);
}
else if constexpr (std::is_same_v<std::decay_t<T>, long>)
{
s.set_num64(value);
}
else if constexpr (std::is_same_v<std::decay_t<T>, float>)
{
s.set_f(value);
}
else if constexpr (std::is_same_v<std::decay_t<T>, double>)
{
s.set_d(value);
}
else
{
s.set_byte(to_string(value));
}
s.set_party_id(party_id);
return s;
}

// 複数シェアを生成する場合
// 約1mbごとに分割して生成する
Expand All @@ -78,27 +40,26 @@ class Client
{
std::vector<computationtocomputation::Shares> share_vec;
share_vec.reserve(length);
size_t addressId_size = sizeof(qmpc::Share::AddressId);
size_t size = 0;
computationtocomputation::Shares s;
auto a = s.mutable_address_id();
a->set_share_id(share_ids[0].getShareId());
a->set_job_id(share_ids[0].getJobId());
a->set_party_id(party_id);
// std::cout << "Client party , job, share" << party_id << " " << share_ids[0].getJobId()
// << " " << share_ids[0].getShareId() << std::endl;
for (unsigned int i = 0; i < length; i++)
{
// string型のバイト数の取得
size_t value_size = sizeof(values[i]);
// ShareId,JobId,ThreadId,PartyIdの16byte
if (size + value_size + addressId_size + sizeof(s.party_id()) > 1000000)
if (size + value_size > 1000000)
{
size = 0;
s.set_party_id(party_id);
share_vec.push_back(s);
s = computationtocomputation::Shares{};
}
// 一つのsharesにつきPartyIdは一つだけなので分割しない際はShareId,JobId,ThreadIdの12byte
size = size + value_size + addressId_size;
size = size + value_size;
computationtocomputation::Shares_Share *multiple_shares = s.add_share_list();
auto a = multiple_shares->mutable_address_id();
a->set_share_id(share_ids[i].getShareId());
a->set_job_id(share_ids[i].getJobId());
if constexpr (std::is_same_v<std::decay_t<T>, bool>)
{
multiple_shares->set_flag(values[i]);
Expand All @@ -124,7 +85,6 @@ class Client
multiple_shares->set_byte(to_string(values[i]));
}
}
s.set_party_id(party_id);
share_vec.push_back(s);
return share_vec;
}
Expand All @@ -136,22 +96,12 @@ class Client
template <typename T>
bool exchangeShare(T &&value, qmpc::Share::AddressId share_id, int party_id) const
{
// リクエスト設定
computationtocomputation::Share share;
google::protobuf::Empty response;
share = makeShare(value, share_id, party_id);
grpc::Status status;
std::vector<std::decay_t<T>> values;
values.emplace_back(value);
std::vector<qmpc::Share::AddressId> share_ids;
share_ids.emplace_back(share_id);

// リトライポリシーに従ってリクエストを送る
auto retry_manager = RetryManager("CC", "exchangeShare");
do
{
grpc::ClientContext context;
status = stub_->ExchangeShare(&context, share, &response);
} while (retry_manager.retry(status));

// 送信に成功
return true;
return exchangeShares(values, share_ids, 1, party_id);
}

// 複数シェアを一括exchangeする場合
Expand All @@ -168,15 +118,14 @@ class Client
google::protobuf::Empty response;
shares = makeShares(values, share_ids, length, party_id);
grpc::Status status;

// リトライポリシーに従ってリクエストを送る
auto retry_manager = RetryManager("CC", "exchangeShares");
grpc::ClientContext context;
std::shared_ptr<grpc::ClientWriter<computationtocomputation::Shares>> stream(
stub_->ExchangeShares(&context, &response)
);
do
{
grpc::ClientContext context;
std::shared_ptr<grpc::ClientWriter<computationtocomputation::Shares>> stream(
stub_->ExchangeShares(&context, &response)
);
for (size_t i = 0; i < shares.size(); i++)
{
if (!stream->Write(shares[i]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,65 +72,68 @@ grpc::Status Server::ExchangeShare(
return grpc::Status::OK;
}

static std::string share_to_str(const computationtocomputation::Shares_Share &share)
{
using cs = computationtocomputation::Shares_Share;
switch (share.value_case())
{
case (cs::ValueCase::kFlag):
return std::to_string(share.flag());
case (cs::ValueCase::kNum):
return std::to_string(share.num());
case (cs::ValueCase::kNum64):
return std::to_string(share.num64());
case (cs::ValueCase::kF):
return std::to_string(share.f());
case (cs::ValueCase::kD):
return std::to_string(share.d());
case (cs::ValueCase::kByte):
case (cs::ValueCase::VALUE_NOT_SET):
return share.byte();
}
}
// 複数シェアをexchangeする場合
grpc::Status Server::ExchangeShares(
grpc::ServerContext *context,
grpc::ServerReader<computationtocomputation::Shares> *stream,
google::protobuf::Empty *response
)
{
std::lock_guard<std::mutex> lock(mtx); // mutex発動

using cs = computationtocomputation::Shares_Share;
computationtocomputation::Shares multiple_shares;
bool first = true;
int party_id, share_id, job_id, thread_id;
std::vector<std::string> share_str_vec;

while (stream->Read(&multiple_shares))
{
int party_id = multiple_shares.party_id();
if (first)
{
auto address = multiple_shares.address_id();
party_id = address.party_id();
share_id = address.share_id();
job_id = address.job_id();
thread_id = address.thread_id();
}
// std::cout << "party_id is " << party_id << std::endl;
// std::cout << "share_id is " << share_id << std::endl;
// std::cout << "job_id is " << job_id << std::endl;
// std::cout << "thread_id is " << thread_id << std::endl;
for (int i = 0; i < multiple_shares.share_list_size(); i++)
{
auto share = multiple_shares.share_list(i);
auto address = share.address_id();
switch (share.value_case())
{
case (cs::ValueCase::kFlag):
shares[std::make_tuple(
party_id, address.share_id(), address.job_id(), address.thread_id()
)] = std::to_string(share.flag());
break;
case (cs::ValueCase::kNum):
shares[std::make_tuple(
party_id, address.share_id(), address.job_id(), address.thread_id()
)] = std::to_string(share.num());
break;
case (cs::ValueCase::kNum64):
shares[std::make_tuple(
party_id, address.share_id(), address.job_id(), address.thread_id()
)] = std::to_string(share.num64());
break;
case (cs::ValueCase::kF):
shares[std::make_tuple(
party_id, address.share_id(), address.job_id(), address.thread_id()
)] = std::to_string(share.f());
break;
case (cs::ValueCase::kD):
shares[std::make_tuple(
party_id, address.share_id(), address.job_id(), address.thread_id()
)] = std::to_string(share.d());
break;
case (cs::ValueCase::kByte):
case (cs::ValueCase::VALUE_NOT_SET):
shares[std::make_tuple(
party_id, address.share_id(), address.job_id(), address.thread_id()
)] = share.byte();
break;
}

auto share_str = share_to_str(share);
share_str_vec.emplace_back(share_str);
}
first = false;
}

std::lock_guard<std::mutex> lock(mtx); // mutex発動
shares_vec[std::make_tuple(party_id, share_id, job_id, thread_id)] = share_str_vec;

cond.notify_all(); // 通知
return grpc::Status::OK;
}

// 単一シェアget用
std::string Server::getShare(int party_id, qmpc::Share::AddressId share_id)
{
Expand All @@ -142,14 +145,14 @@ std::string Server::getShare(int party_id, qmpc::Share::AddressId share_id)
if (!cond.wait_for(
lock,
std::chrono::seconds(conf->getshare_time_limit),
[&] { return shares.count(key) == 1; }
[&] { return shares_vec.count(key) == 1; }
)) // 待機
{
qmpc::Log::throw_with_trace(std::runtime_error("getShare is timeout"));
}
std::string share = shares[key];
shares.erase(key);
return share;
auto share = shares_vec[key];
shares_vec.erase(key);
return share[0];
}

// 複数シェアget用
Expand All @@ -158,27 +161,27 @@ std::vector<std::string> Server::getShares(
)
{
Config *conf = Config::getInstance();
// std::cout << "party share job thread"
// << " " << party_id << " " << share_ids[0].getShareId() << " "
// << share_ids[0].getJobId() << " " << share_ids[0].getThreadId() << std::endl;
std::vector<std::string> str_values;
str_values.reserve(length);
for (unsigned int i = 0; i < length; i++)
auto key = std::make_tuple(
party_id, share_ids[0].getShareId(), share_ids[0].getJobId(), share_ids[0].getThreadId()
);
std::unique_lock<std::mutex> lock(mtx); // mutex発動

if (!cond.wait_for(
lock,
std::chrono::seconds(conf->getshare_time_limit),
[&] { return shares_vec.count(key) == 1; }
)) // 待機
{
std::unique_lock<std::mutex> lock(mtx); // mutex発動
auto key = std::make_tuple(
party_id, share_ids[i].getShareId(), share_ids[i].getJobId(), share_ids[i].getThreadId()
);
if (!cond.wait_for(
lock,
std::chrono::seconds(conf->getshare_time_limit),
[&] { return shares.count(key) == 1; }
)) // 待機
{
qmpc::Log::throw_with_trace(std::runtime_error("getShares is timeout"));
}
std::string share = shares[key];
str_values.emplace_back(share);
shares.erase(key);
qmpc::Log::throw_with_trace(std::runtime_error("getShares is timeout"));
}
return str_values;
auto local_str_shares = shares_vec[key];
shares_vec.erase(key);
return local_str_shares;
}

void Server::runServer(std::string endpoint)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class Server final : public computationtocomputation::ComputationToComputation::
// 受け取ったシェアを保存する変数
// party_id, share_idをキーとして保存
std::map<address, std::string> shares;
std::map<address, std::vector<std::string>> shares_vec;
};

} // namespace qmpc::ComputationToComputation
5 changes: 4 additions & 1 deletion packages/server/computation_container/share/networking.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,11 @@ void send(const T &shares, const int &pt_id)
}
else if constexpr (is_share<T>::value)
{
using SV = typename T::value_type;
// pt_idで指定されたパーティにシェアの値を送信する
client->exchangeShare(shares.getVal(), shares.getId(), conf->party_id);
std::vector<SV> str_value = {shares.getVal()};
std::vector<qmpc::Share::AddressId> share_id = {shares.getId()};
client->exchangeShares(str_value, share_id, 1, conf->party_id);
}
}
template <class SV>
Expand Down
Loading

0 comments on commit e719675

Please sign in to comment.