diff --git a/Makefile b/Makefile index c0e48a96..13c8a2cf 100644 --- a/Makefile +++ b/Makefile @@ -49,7 +49,7 @@ federated_insecure: @echo Finished Training on VirtualWorkers without SecAgg federated_gridnode_secure: - python train.py --config configs/torch/pneumonia-resnet-pretrained.ini --train_federated --data_dir data/server_simulation --websockets + python train.py --config configs/torch/pneumonia-resnet-pretrained.ini --train_federated --websockets --data_dir data/server_simulation federated_gridnode_insecure: python train.py --config configs/torch/pneumonia-resnet-pretrained.ini --train_federated --data_dir data/server_simulation --websockets --unencrypted_aggregation diff --git a/inference.py b/inference.py index e0246e5e..8a4afb36 100644 --- a/inference.py +++ b/inference.py @@ -63,7 +63,7 @@ ) parser.add_argument("--cuda", action="store_true", help="Use CUDA acceleration.") parser.add_argument( - "--http_protocol", action="store_true", help="Use HTTP instead of WS." + "--http_protocol", action="store_true", help="Use HTTP only instead of WS." ) cmd_args = parser.parse_args() @@ -73,7 +73,8 @@ if not cmd_args.http_protocol: warn( - "Under certain circumstances, WebSockets can fail when performing encrypted inference. If you experience errors related to 'rsv' not being implemented, consider enabling HTTP." + """Under certain circumstances, WebSockets can fail when performing encrypted inference. + If you experience WebSocket-related errors, consider using HTTP only with the --http_protocol flag.""" ) args = state["args"] @@ -111,14 +112,16 @@ assert ( "crypto_provider" in worker_names ), "No crypto_provider in websockets config" - crypto_provider = sy.grid.clients.data_centric_fl_client.DataCentricFLClient( - hook, - "{:s}://{:s}:{:s}".format( - "http" if cmd_args.http_protocol else "ws", - worker_dict["crypto_provider"]["host"], - worker_dict["crypto_provider"]["port"], - ), - http_protocol=cmd_args.http_protocol, + crypto_provider = ( + sy.grid.clients.data_centric_fl_client.DataCentricFLClient( + hook, + "{:s}://{:s}:{:s}".format( + "http" if cmd_args.http_protocol else "ws", + worker_dict["crypto_provider"]["host"], + worker_dict["crypto_provider"]["port"], + ), + http_protocol=cmd_args.http_protocol, + ) ) model_owner = sy.grid.clients.data_centric_fl_client.DataCentricFLClient( hook, @@ -188,7 +191,11 @@ if not args.pretrained: loader.change_channels(1) if not cmd_args.websockets_config: - dataset = PathDataset(cmd_args.data_dir, transform=tf, loader=loader,) + dataset = PathDataset( + cmd_args.data_dir, + transform=tf, + loader=loader, + ) if cmd_args.encrypted_inference: data = [] for d in tqdm(dataset, total=len(dataset), leave=False, desc="load data"): @@ -251,7 +258,7 @@ model.load_state_dict(state["model_state_dict"]) model.to(device) if args.encrypted_inference: - fix_prec_kwargs = {"precision_fractional": 4, "dtype": "long"} + fix_prec_kwargs = {"precision_fractional": 16, "dtype": "long"} share_kwargs = { "crypto_provider": crypto_provider, "protocol": "fss", @@ -296,4 +303,3 @@ sys.stdout.write(json.dumps(pred_dict)) print("\n{:s}".format(str(Counter(total_pred)))) - diff --git a/requirements.txt b/requirements.txt index 55b6212a..56ee2b7d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,4 +29,4 @@ sqlalchemy_utils ==0.36.8 wsaccel ==0.6.2 papermill ==2.0.0 bcrypt ==3.2.0 -syft==0.2.9 \ No newline at end of file +#syft==0.2.9 \ No newline at end of file