diff --git a/.gitignore b/.gitignore index 05651421..369831a7 100644 --- a/.gitignore +++ b/.gitignore @@ -153,4 +153,5 @@ one_test_sample/ src/ ignore_files/ PySyft -*syft* \ No newline at end of file +*syft* +.inference diff --git a/Makefile b/Makefile index 3d403199..17ef2597 100644 --- a/Makefile +++ b/Makefile @@ -98,7 +98,7 @@ crypto_provider: python -m Node --id crypto_provider --port 8780 model_owner: - python -m Node --id crypto_provider --port 8771 + python -m Node --id model_owner --port 8771 inference_setup: make data_owner & make crypto_provider & make model_owner diff --git a/inference.py b/inference.py index c47d51db..4242c967 100644 --- a/inference.py +++ b/inference.py @@ -141,7 +141,9 @@ def __getitem__(self, idx): data_owner = sy.grid.clients.data_centric_fl_client.DataCentricFLClient( hook, "http://{:s}:{:s}".format( - worker_dict["data_owner"]["host"], worker_dict["data_owner"]["port"] + worker_dict["data_owner"]["host"], + worker_dict["data_owner"]["port"], + is_client_worker=True, ), ) if cmd_args.encrypted_inference: @@ -153,6 +155,7 @@ def __getitem__(self, idx): "http://{:s}:{:s}".format( worker_dict["crypto_provider"]["host"], worker_dict["crypto_provider"]["port"], + is_client_worker=True, ), ) model_owner = sy.grid.clients.data_centric_fl_client.DataCentricFLClient( @@ -160,6 +163,7 @@ def __getitem__(self, idx): "http://{:s}:{:s}".format( worker_dict["model_owner"]["host"], worker_dict["model_owner"]["port"], + is_client_worker=True, ), ) else: @@ -284,18 +288,12 @@ def __getitem__(self, idx): "protocol": "fss", "requires_grad": False, } - # model.send(model_owner) model.fix_precision(precision_fractional=4, dtype="long").share( - *workers, - crypto_provider=crypto_provider, - protocol="fss", - requires_grad=False + *workers, crypto_provider=crypto_provider, requires_grad=False ) # test method model.eval() total_pred, total_target, total_scores = [], [], [] - if args.encrypted_inference: - mean, std = mean.send(data_owner), std.send(data_owner) with torch.no_grad(): for i, data in tqdm( enumerate(dataset), @@ -322,8 +320,8 @@ def __getitem__(self, idx): .share( *workers, crypto_provider=crypto_provider, - protocol="fss", - requires_grad=False + requires_grad=False, + protocol="fss" ) .get() )