Skip to content

Commit

Permalink
Makefile and inference.py corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
gkaissis committed Aug 27, 2020
1 parent 142e6cc commit 51ad9d7
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,5 @@ one_test_sample/
src/
ignore_files/
PySyft
*syft*
*syft*
.inference
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ crypto_provider:
python -m Node --id crypto_provider --port 8780

model_owner:
python -m Node --id crypto_provider --port 8771
python -m Node --id model_owner --port 8771

inference_setup:
make data_owner & make crypto_provider & make model_owner
Expand Down
18 changes: 8 additions & 10 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ def __getitem__(self, idx):
data_owner = sy.grid.clients.data_centric_fl_client.DataCentricFLClient(
hook,
"http://{:s}:{:s}".format(
worker_dict["data_owner"]["host"], worker_dict["data_owner"]["port"]
worker_dict["data_owner"]["host"],
worker_dict["data_owner"]["port"],
is_client_worker=True,
),
)
if cmd_args.encrypted_inference:
Expand All @@ -153,13 +155,15 @@ def __getitem__(self, idx):
"http://{:s}:{:s}".format(
worker_dict["crypto_provider"]["host"],
worker_dict["crypto_provider"]["port"],
is_client_worker=True,
),
)
model_owner = sy.grid.clients.data_centric_fl_client.DataCentricFLClient(
hook,
"http://{:s}:{:s}".format(
worker_dict["model_owner"]["host"],
worker_dict["model_owner"]["port"],
is_client_worker=True,
),
)
else:
Expand Down Expand Up @@ -284,18 +288,12 @@ def __getitem__(self, idx):
"protocol": "fss",
"requires_grad": False,
}
# model.send(model_owner)
model.fix_precision(precision_fractional=4, dtype="long").share(
*workers,
crypto_provider=crypto_provider,
protocol="fss",
requires_grad=False
*workers, crypto_provider=crypto_provider, requires_grad=False
)
# test method
model.eval()
total_pred, total_target, total_scores = [], [], []
if args.encrypted_inference:
mean, std = mean.send(data_owner), std.send(data_owner)
with torch.no_grad():
for i, data in tqdm(
enumerate(dataset),
Expand All @@ -322,8 +320,8 @@ def __getitem__(self, idx):
.share(
*workers,
crypto_provider=crypto_provider,
protocol="fss",
requires_grad=False
requires_grad=False,
protocol="fss"
)
.get()
)
Expand Down

0 comments on commit 51ad9d7

Please sign in to comment.