Skip to content

Commit

Permalink
Merge pull request #242 from acompany-develop/feature/nakata/add_addr…
Browse files Browse the repository at this point in the history
…equest

Add AddShareDataFrame request
  • Loading branch information
mdonaka authored Jun 30, 2023
2 parents c1068b7 + 39907ea commit 2eda104
Show file tree
Hide file tree
Showing 23 changed files with 1,114 additions and 203 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 36 additions & 0 deletions packages/client/libclient-py/quickmpc/proto/libc_to_manage_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,39 @@ class GetJobErrorInfoResponse(google.protobuf.message.Message):
def WhichOneof(self, oneof_group: typing_extensions.Literal["_job_error_info", b"_job_error_info"]) -> typing_extensions.Literal["job_error_info"] | None: ...

global___GetJobErrorInfoResponse = GetJobErrorInfoResponse

@typing_extensions.final
class AddShareDataFrameRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

BASE_DATA_ID_FIELD_NUMBER: builtins.int
ADD_DATA_ID_FIELD_NUMBER: builtins.int
TOKEN_FIELD_NUMBER: builtins.int
base_data_id: builtins.str
add_data_id: builtins.str
token: builtins.str
def __init__(
self,
*,
base_data_id: builtins.str = ...,
add_data_id: builtins.str = ...,
token: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["add_data_id", b"add_data_id", "base_data_id", b"base_data_id", "token", b"token"]) -> None: ...

global___AddShareDataFrameRequest = AddShareDataFrameRequest

@typing_extensions.final
class AddShareDataFrameResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

DATA_ID_FIELD_NUMBER: builtins.int
data_id: builtins.str
def __init__(
self,
*,
data_id: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["data_id", b"data_id"]) -> None: ...

global___AddShareDataFrameResponse = AddShareDataFrameResponse
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def __init__(self, channel):
request_serializer=libc__to__manage__pb2.GetJobErrorInfoRequest.SerializeToString,
response_deserializer=libc__to__manage__pb2.GetJobErrorInfoResponse.FromString,
)
self.AddShareDataFrame = channel.unary_unary(
'/libctomanage.LibcToManage/AddShareDataFrame',
request_serializer=libc__to__manage__pb2.AddShareDataFrameRequest.SerializeToString,
response_deserializer=libc__to__manage__pb2.AddShareDataFrameResponse.FromString,
)


class LibcToManageServicer(object):
Expand Down Expand Up @@ -112,6 +117,12 @@ def GetJobErrorInfo(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def AddShareDataFrame(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_LibcToManageServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand Down Expand Up @@ -155,6 +166,11 @@ def add_LibcToManageServicer_to_server(servicer, server):
request_deserializer=libc__to__manage__pb2.GetJobErrorInfoRequest.FromString,
response_serializer=libc__to__manage__pb2.GetJobErrorInfoResponse.SerializeToString,
),
'AddShareDataFrame': grpc.unary_unary_rpc_method_handler(
servicer.AddShareDataFrame,
request_deserializer=libc__to__manage__pb2.AddShareDataFrameRequest.FromString,
response_serializer=libc__to__manage__pb2.AddShareDataFrameResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'libctomanage.LibcToManage', rpc_method_handlers)
Expand Down Expand Up @@ -303,3 +319,20 @@ def GetJobErrorInfo(request,
libc__to__manage__pb2.GetJobErrorInfoResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def AddShareDataFrame(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/libctomanage.LibcToManage/AddShareDataFrame',
libc__to__manage__pb2.AddShareDataFrameRequest.SerializeToString,
libc__to__manage__pb2.AddShareDataFrameResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ class LibcToManageStub:
libc_to_manage_pb2.GetJobErrorInfoRequest,
libc_to_manage_pb2.GetJobErrorInfoResponse,
]
AddShareDataFrame: grpc.UnaryUnaryMultiCallable[
libc_to_manage_pb2.AddShareDataFrameRequest,
libc_to_manage_pb2.AddShareDataFrameResponse,
]

class LibcToManageServicer(metaclass=abc.ABCMeta):
"""*
Expand Down Expand Up @@ -100,5 +104,11 @@ class LibcToManageServicer(metaclass=abc.ABCMeta):
request: libc_to_manage_pb2.GetJobErrorInfoRequest,
context: grpc.ServicerContext,
) -> libc_to_manage_pb2.GetJobErrorInfoResponse: ...
@abc.abstractmethod
def AddShareDataFrame(
self,
request: libc_to_manage_pb2.AddShareDataFrameRequest,
context: grpc.ServicerContext,
) -> libc_to_manage_pb2.AddShareDataFrameResponse: ...

def add_LibcToManageServicer_to_server(servicer: LibcToManageServicer, server: grpc.Server) -> None: ...
2 changes: 0 additions & 2 deletions packages/client/libclient-py/quickmpc/qmpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ class QMPC:
def __post_init__(self, endpoints: List[str],
retry_num: int, retry_wait_time: int):
logger.info(f"[QuickMPC server IP]={endpoints}")
object.__setattr__(self, "_QMPC__qmpc_request", QMPCRequest(
endpoints, retry_num, retry_wait_time))
object.__setattr__(self, "_QMPC__qmpc_request", QMPCRequest(
endpoints, retry_num, retry_wait_time))
object.__setattr__(self, "_QMPC__party_size", len(endpoints))
Expand Down
28 changes: 23 additions & 5 deletions packages/client/libclient-py/quickmpc/qmpc_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from .proto.common_types.common_types_pb2 import (ComputationMethod,
JobErrorInfo, JobStatus,
Schema)
from .proto.libc_to_manage_pb2 import (DeleteSharesRequest,
from .proto.libc_to_manage_pb2 import (AddShareDataFrameRequest,
DeleteSharesRequest,
ExecuteComputationRequest,
GetComputationResultRequest,
GetComputationResultResponse,
Expand All @@ -32,10 +33,10 @@
JoinOrder, SendSharesRequest)
from .proto.libc_to_manage_pb2_grpc import LibcToManageStub
from .request.qmpc_request_interface import QMPCRequestInterface
from .request.response import (DeleteShareResponse, ExecuteResponse,
GetDataListResponse, GetElapsedTimeResponse,
GetJobErrorInfoResponse, GetResultResponse,
SendShareResponse)
from .request.response import (AddShareDataFrameResponse, DeleteShareResponse,
ExecuteResponse, GetDataListResponse,
GetElapsedTimeResponse, GetJobErrorInfoResponse,
GetResultResponse, SendShareResponse)
from .request.status import Status
from .share import Share
from .utils.if_present import if_present
Expand Down Expand Up @@ -458,3 +459,20 @@ def delete_share(self, data_ids: List[str]) -> DeleteShareResponse:
if is_ok:
return DeleteShareResponse(Status.OK)
return DeleteShareResponse(Status.BadGateway)

def add_share_data_frame(self, base_data_id: str, add_data_id: str) \
-> AddShareDataFrameResponse:
req = AddShareDataFrameRequest(base_data_id=base_data_id,
add_data_id=add_data_id,
token=self.__token)
# 非同期にリクエスト送信
with ThreadPoolExecutor() as executor:
futures = [executor.submit(self.__retry,
stub.AddShareDataFrame, req)
for stub in self.__client_stubs]
is_ok, response = QMPCRequest.__futures_result(futures)
data_id = response[0].data_id if is_ok else ""
# TODO: __futures_resultの返り値を適切なものに変更する
if is_ok:
return AddShareDataFrameResponse(Status.OK, data_id)
return AddShareDataFrameResponse(Status.BadGateway, data_id)
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

import pandas as pd

from .response import (DeleteShareResponse, ExecuteResponse,
GetDataListResponse, GetElapsedTimeResponse,
GetJobErrorInfoResponse, GetResultResponse,
SendShareResponse)
from .response import (AddShareDataFrameResponse, DeleteShareResponse,
ExecuteResponse, GetDataListResponse,
GetElapsedTimeResponse, GetJobErrorInfoResponse,
GetResultResponse, SendShareResponse)


class QMPCRequestInterface(ABC):
Expand Down Expand Up @@ -60,3 +60,7 @@ def get_data_list(self) -> GetDataListResponse: ...

@abstractmethod
def delete_share(self, data_ids: List[str]) -> DeleteShareResponse: ...

@abstractmethod
def add_share_data_frame(self, base_data_id: str, add_data_id: str) \
-> AddShareDataFrameResponse: ...
6 changes: 6 additions & 0 deletions packages/client/libclient-py/quickmpc/request/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,9 @@ class GetDataListResponse():
@dataclass(frozen=True)
class DeleteShareResponse():
status: Status


@dataclass(frozen=True)
class AddShareDataFrameResponse():
status: Status
data_id: str
18 changes: 18 additions & 0 deletions packages/client/libclient-py/quickmpc/share_data_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,24 @@ def _wait_execute(self):
break
self.__status = ShareDataFrameStatus.OK

def __add__(self, other: "ShareDataFrame") -> "ShareDataFrame":
"""テーブルを加算する.
qmpc.send_toで送ったデータでかつ,行数,列数が一致している場合のみ正常に動作する.
Parameters
----------
other: ShareDataFrame
結合したいDataFrame
Returns
----------
Result
加算して得られたDataFrameのResult
"""
res = self.__qmpc_request.add_share_data_frame(self.__id, other.__id)
return ShareDataFrame(res.data_id, self.__qmpc_request)

@methoddispatch()
def join(self, other: "ShareDataFrame", *, debug_mode=False) \
-> "ShareDataFrame":
Expand Down
6 changes: 6 additions & 0 deletions packages/client/libclient-py/tests/unit_tests/local_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ def GetSchema(self, request, context):
res = libc_to_manage_pb2.GetSchemaResponse()
return res

def AddShareDataFrame(self, request, context):
res = libc_to_manage_pb2.AddShareDataFrameResponse(
data_id="data_id"
)
return res


def serve(num: int):
""" server setting """
Expand Down
Loading

0 comments on commit 2eda104

Please sign in to comment.