Skip to content

Commit

Permalink
improve http documentation, fix makefile and reqs
Browse files Browse the repository at this point in the history
  • Loading branch information
Georgios Kaissis committed Sep 21, 2020
1 parent de3db4c commit 042dec5
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ federated_insecure:
@echo Finished Training on VirtualWorkers without SecAgg

federated_gridnode_secure:
python train.py --config configs/torch/pneumonia-resnet-pretrained.ini --train_federated --data_dir data/server_simulation --websockets
python train.py --config configs/torch/pneumonia-resnet-pretrained.ini --train_federated --websockets --data_dir data/server_simulation

federated_gridnode_insecure:
python train.py --config configs/torch/pneumonia-resnet-pretrained.ini --train_federated --data_dir data/server_simulation --websockets --unencrypted_aggregation
Expand Down
32 changes: 19 additions & 13 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
)
parser.add_argument("--cuda", action="store_true", help="Use CUDA acceleration.")
parser.add_argument(
"--http_protocol", action="store_true", help="Use HTTP instead of WS."
"--http_protocol", action="store_true", help="Use HTTP only instead of WS."
)
cmd_args = parser.parse_args()

Expand All @@ -73,7 +73,8 @@

if not cmd_args.http_protocol:
warn(
"Under certain circumstances, WebSockets can fail when performing encrypted inference. If you experience errors related to 'rsv' not being implemented, consider enabling HTTP."
"""Under certain circumstances, WebSockets can fail when performing encrypted inference.
If you experience WebSocket-related errors, consider using HTTP only with the --http_protocol flag."""
)

args = state["args"]
Expand Down Expand Up @@ -111,14 +112,16 @@
assert (
"crypto_provider" in worker_names
), "No crypto_provider in websockets config"
crypto_provider = sy.grid.clients.data_centric_fl_client.DataCentricFLClient(
hook,
"{:s}://{:s}:{:s}".format(
"http" if cmd_args.http_protocol else "ws",
worker_dict["crypto_provider"]["host"],
worker_dict["crypto_provider"]["port"],
),
http_protocol=cmd_args.http_protocol,
crypto_provider = (
sy.grid.clients.data_centric_fl_client.DataCentricFLClient(
hook,
"{:s}://{:s}:{:s}".format(
"http" if cmd_args.http_protocol else "ws",
worker_dict["crypto_provider"]["host"],
worker_dict["crypto_provider"]["port"],
),
http_protocol=cmd_args.http_protocol,
)
)
model_owner = sy.grid.clients.data_centric_fl_client.DataCentricFLClient(
hook,
Expand Down Expand Up @@ -188,7 +191,11 @@
if not args.pretrained:
loader.change_channels(1)
if not cmd_args.websockets_config:
dataset = PathDataset(cmd_args.data_dir, transform=tf, loader=loader,)
dataset = PathDataset(
cmd_args.data_dir,
transform=tf,
loader=loader,
)
if cmd_args.encrypted_inference:
data = []
for d in tqdm(dataset, total=len(dataset), leave=False, desc="load data"):
Expand Down Expand Up @@ -251,7 +258,7 @@
model.load_state_dict(state["model_state_dict"])
model.to(device)
if args.encrypted_inference:
fix_prec_kwargs = {"precision_fractional": 4, "dtype": "long"}
fix_prec_kwargs = {"precision_fractional": 16, "dtype": "long"}
share_kwargs = {
"crypto_provider": crypto_provider,
"protocol": "fss",
Expand Down Expand Up @@ -296,4 +303,3 @@
sys.stdout.write(json.dumps(pred_dict))

print("\n{:s}".format(str(Counter(total_pred))))

2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ sqlalchemy_utils ==0.36.8
wsaccel ==0.6.2
papermill ==2.0.0
bcrypt ==3.2.0
syft==0.2.9
#syft==0.2.9

0 comments on commit 042dec5

Please sign in to comment.