Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduced memory in benchmark #232

Merged
merged 8 commits into from
Jun 13, 2023
Merged
2 changes: 1 addition & 1 deletion packages/client/libclient-py/quickmpc/qmpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .qmpc_server import QMPCServer
from .share import Share
from .utils.parse_csv import parse, parse_csv
from .utils.restore import restore
from .restore import restore

logger = logging.getLogger(__name__)
# qmpc.JobStatus でアクセスできるようにエイリアスを設定する
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
import numpy as np
from natsort import natsorted

import google.protobuf.json_format
from .proto.common_types.common_types_pb2 import Schema
from .share import Share
from .utils.if_present import if_present


def get_meta(job_uuid: str, path: str):
file_name = glob.glob(f"{path}/dim?-{job_uuid}-*")[0]
Expand All @@ -28,30 +33,44 @@ def get_result(job_uuid: str, path: str, party: int):


def restore(job_uuid: str, path: str, party_size: int) -> Any:
schema = []
column_number = get_meta(job_uuid, path)

schema = [None]*column_number
result: Any = []

column_number = get_meta(job_uuid, path)
is_schema = True if len(
glob.glob(f"{path}/schema-{job_uuid}-*")) != 0 else False
is_dim2 = True if len(
glob.glob(f"{path}/dim2-{job_uuid}-*")) != 0 else False
if column_number == 0:
return [[]] if is_dim2 else []
if is_schema:
return {"schema": [], "table": [[]]}
elif is_dim2:
return [[]]
else:
return []

for party in range(party_size):
if party == 0:
for val in get_result(job_uuid, f"{path}/schema", party):
schema.append(val)
for i, val in enumerate(get_result(job_uuid, f"{path}/schema", party)):
col_sch = google.protobuf.json_format.Parse(
val, Schema())
schema[i] = col_sch

itr = 0
for val in get_result(job_uuid, f"{path}/dim?", party):
f = Share.get_pre_convert_func(schema[itr % column_number])
if itr >= len(result):
result.append(Decimal(val))
result.append(f(val))
else:
result[itr] += Decimal(val)
result[itr] += f(val)
itr += 1
result = np.array(result)\
.reshape(-1, column_number).tolist() if is_dim2 else result
if is_dim2 and len(result) == 0:
schema = []
result = [[]]

result_float = np.vectorize(float)(result)
result = np.array(result_float)\
.reshape(-1, column_number).tolist() if is_dim2 else result_float
result = {"schema": schema, "table": result} if len(schema) else result
results = if_present(result, Share.convert_type, schema)
result = {"schema": schema, "table": results} if is_schema else results
return result
14 changes: 11 additions & 3 deletions scripts/libclient/src/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

import time

import glob
import os
from quickmpc import QMPC, JobStatus

qmpc: QMPC = QMPC([
Expand All @@ -9,9 +10,11 @@
"http://localhost:50003",
])

def __try_get_computation_result(job_uuid, is_limit):
def __try_get_computation_result(job_uuid, is_limit, path = './result'):
try:
get_res = qmpc.get_computation_result(job_uuid)
if not os.path.isdir(path):
os.mkdir(path)
get_res = qmpc.get_computation_result(job_uuid, path)
except:
if is_limit:
raise
Expand All @@ -23,7 +26,12 @@ def __try_get_computation_result(job_uuid, is_limit):
all_completed = all([status == JobStatus.COMPLETED
for status in get_res["statuses"]])

if glob.glob(f"{path}/schema*{job_uuid}-*") == 0:
return None
KotaTakahashi9320 marked this conversation as resolved.
Show resolved Hide resolved
if all_completed:
res = qmpc.restore(job_uuid, path)
print(job_uuid)
KotaTakahashi9320 marked this conversation as resolved.
Show resolved Hide resolved
get_res["results"] = res
# NOTE
# 計算結果取得時には含まれていないが,job_uuidがmodelの取得に使われるので
# このような処理をしている
Expand Down