diff --git a/examples/custom-metrics/client.py b/examples/custom-metrics/client.py index 09e786a0cfac..d0230e455477 100644 --- a/examples/custom-metrics/client.py +++ b/examples/custom-metrics/client.py @@ -68,4 +68,6 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) +fl.client.start_client( + server_address="127.0.0.1:8080", client=FlowerClient().to_client() +) diff --git a/examples/embedded-devices/client_pytorch.py b/examples/embedded-devices/client_pytorch.py index 134b573f7608..f326db7c678c 100644 --- a/examples/embedded-devices/client_pytorch.py +++ b/examples/embedded-devices/client_pytorch.py @@ -30,12 +30,13 @@ "--mnist", action="store_true", help="If you use Raspberry Pi Zero clients (which just have 512MB or RAM) use " - "MNIST", + "MNIST", ) warnings.filterwarnings("ignore", category=UserWarning) NUM_CLIENTS = 50 + class Net(nn.Module): """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz').""" @@ -96,6 +97,7 @@ def prepare_dataset(use_mnist: bool): img_key = "img" norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) pytorch_transforms = Compose([ToTensor(), norm]) + def apply_transforms(batch): """Apply transforms to the partition from FederatedDataset.""" batch[img_key] = [pytorch_transforms(img) for img in batch[img_key]] @@ -112,7 +114,7 @@ def apply_transforms(batch): validsets.append(partition["test"]) testset = fds.load_full("test") testset = testset.with_transform(apply_transforms) - return trainsets, validsets, testset + return trainsets, validsets, testset # Flower client, adapted from Pytorch quickstart/simulation example diff --git a/examples/embedded-devices/client_tf.py b/examples/embedded-devices/client_tf.py index 712068af5fc4..ae793ecd81e0 100644 --- a/examples/embedded-devices/client_tf.py +++ b/examples/embedded-devices/client_tf.py @@ -45,8 +45,10 @@ def prepare_dataset(use_mnist: bool): partition.set_format("numpy") # Divide data on each node: 90% train, 10% test partition = partition.train_test_split(test_size=0.1) - x_train, y_train = partition["train"][img_key] / 255.0, partition["train"][ - "label"] + x_train, y_train = ( + partition["train"][img_key] / 255.0, + partition["train"]["label"], + ) x_test, y_test = partition["test"][img_key] / 255.0, partition["test"]["label"] partitions.append(((x_train, y_train), (x_test, y_test))) data_centralized = fds.load_full("test") @@ -123,7 +125,9 @@ def main(): # Start Flower client setting its associated data partition fl.client.start_client( server_address=args.server_address, - client=FlowerClient(trainset=trainset, valset=valset, use_mnist=use_mnist).to_client(), + client=FlowerClient( + trainset=trainset, valset=valset, use_mnist=use_mnist + ).to_client(), ) diff --git a/examples/pytorch-federated-variational-autoencoder/client.py b/examples/pytorch-federated-variational-autoencoder/client.py index 65a86bcc2184..fc71f7e70c0b 100644 --- a/examples/pytorch-federated-variational-autoencoder/client.py +++ b/examples/pytorch-federated-variational-autoencoder/client.py @@ -93,7 +93,9 @@ def evaluate(self, parameters, config): loss = test(net, testloader) return float(loss), len(testloader), {} - fl.client.start_client(server_address="127.0.0.1:8080", client=CifarClient().to_client()) + fl.client.start_client( + server_address="127.0.0.1:8080", client=CifarClient().to_client() + ) if __name__ == "__main__": diff --git a/examples/quickstart-huggingface/client.py b/examples/quickstart-huggingface/client.py index 0bfc81342972..5dc461d30536 100644 --- a/examples/quickstart-huggingface/client.py +++ b/examples/quickstart-huggingface/client.py @@ -108,7 +108,9 @@ def evaluate(self, parameters, config): return float(loss), len(testloader), {"accuracy": float(accuracy)} # Start client - fl.client.start_client(server_address="127.0.0.1:8080", client=IMDBClient().to_client()) + fl.client.start_client( + server_address="127.0.0.1:8080", client=IMDBClient().to_client() + ) if __name__ == "__main__": diff --git a/examples/quickstart-jax/client.py b/examples/quickstart-jax/client.py index 0cf74e2d2c05..afd6f197bcde 100644 --- a/examples/quickstart-jax/client.py +++ b/examples/quickstart-jax/client.py @@ -52,4 +52,6 @@ def evaluate( # Start Flower client -fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) +fl.client.start_client( + server_address="127.0.0.1:8080", client=FlowerClient().to_client() +) diff --git a/examples/quickstart-mlcube/client.py b/examples/quickstart-mlcube/client.py index 0470720bc296..46ddd45f52ce 100644 --- a/examples/quickstart-mlcube/client.py +++ b/examples/quickstart-mlcube/client.py @@ -44,7 +44,8 @@ def main(): ) fl.client.start_client( - server_address="0.0.0.0:8080", client=MLCubeClient(workspace=workspace).to_client() + server_address="0.0.0.0:8080", + client=MLCubeClient(workspace=workspace).to_client(), ) diff --git a/examples/quickstart-tabnet/client.py b/examples/quickstart-tabnet/client.py index 53913b2a2a09..2289b1b55b3d 100644 --- a/examples/quickstart-tabnet/client.py +++ b/examples/quickstart-tabnet/client.py @@ -79,4 +79,6 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_client(server_address="127.0.0.1:8080", client=TabNetClient().to_client()) +fl.client.start_client( + server_address="127.0.0.1:8080", client=TabNetClient().to_client() +) diff --git a/examples/quickstart-tensorflow/client.py b/examples/quickstart-tensorflow/client.py index 43121f062c45..37abbbcc46ec 100644 --- a/examples/quickstart-tensorflow/client.py +++ b/examples/quickstart-tensorflow/client.py @@ -52,4 +52,6 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_client(server_address="127.0.0.1:8080", client=CifarClient().to_client()) +fl.client.start_client( + server_address="127.0.0.1:8080", client=CifarClient().to_client() +) diff --git a/examples/sklearn-logreg-mnist/client.py b/examples/sklearn-logreg-mnist/client.py index c5a312b41e61..3d41cb6fbb21 100644 --- a/examples/sklearn-logreg-mnist/client.py +++ b/examples/sklearn-logreg-mnist/client.py @@ -62,4 +62,6 @@ def evaluate(self, parameters, config): # type: ignore return loss, len(X_test), {"accuracy": accuracy} # Start Flower client - fl.client.start_client(server_address="0.0.0.0:8080", client=MnistClient().to_client()) + fl.client.start_client( + server_address="0.0.0.0:8080", client=MnistClient().to_client() + )