Skip to content

Commit

Permalink
Format code examples (#2877)
Browse files Browse the repository at this point in the history
Co-authored-by: Taner Topal <taner@flower.dev>
  • Loading branch information
danieljanes and tanertopal authored Feb 1, 2024
1 parent 1913d76 commit ac77ead
Show file tree
Hide file tree
Showing 10 changed files with 34 additions and 13 deletions.
4 changes: 3 additions & 1 deletion examples/custom-metrics/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,6 @@ def evaluate(self, parameters, config):


# Start Flower client
fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client())
fl.client.start_client(
server_address="127.0.0.1:8080", client=FlowerClient().to_client()
)
6 changes: 4 additions & 2 deletions examples/embedded-devices/client_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@
"--mnist",
action="store_true",
help="If you use Raspberry Pi Zero clients (which just have 512MB or RAM) use "
"MNIST",
"MNIST",
)

warnings.filterwarnings("ignore", category=UserWarning)
NUM_CLIENTS = 50


class Net(nn.Module):
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')."""

Expand Down Expand Up @@ -96,6 +97,7 @@ def prepare_dataset(use_mnist: bool):
img_key = "img"
norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
pytorch_transforms = Compose([ToTensor(), norm])

def apply_transforms(batch):
"""Apply transforms to the partition from FederatedDataset."""
batch[img_key] = [pytorch_transforms(img) for img in batch[img_key]]
Expand All @@ -112,7 +114,7 @@ def apply_transforms(batch):
validsets.append(partition["test"])
testset = fds.load_full("test")
testset = testset.with_transform(apply_transforms)
return trainsets, validsets, testset
return trainsets, validsets, testset


# Flower client, adapted from Pytorch quickstart/simulation example
Expand Down
10 changes: 7 additions & 3 deletions examples/embedded-devices/client_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ def prepare_dataset(use_mnist: bool):
partition.set_format("numpy")
# Divide data on each node: 90% train, 10% test
partition = partition.train_test_split(test_size=0.1)
x_train, y_train = partition["train"][img_key] / 255.0, partition["train"][
"label"]
x_train, y_train = (
partition["train"][img_key] / 255.0,
partition["train"]["label"],
)
x_test, y_test = partition["test"][img_key] / 255.0, partition["test"]["label"]
partitions.append(((x_train, y_train), (x_test, y_test)))
data_centralized = fds.load_full("test")
Expand Down Expand Up @@ -123,7 +125,9 @@ def main():
# Start Flower client setting its associated data partition
fl.client.start_client(
server_address=args.server_address,
client=FlowerClient(trainset=trainset, valset=valset, use_mnist=use_mnist).to_client(),
client=FlowerClient(
trainset=trainset, valset=valset, use_mnist=use_mnist
).to_client(),
)


Expand Down
4 changes: 3 additions & 1 deletion examples/pytorch-federated-variational-autoencoder/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def evaluate(self, parameters, config):
loss = test(net, testloader)
return float(loss), len(testloader), {}

fl.client.start_client(server_address="127.0.0.1:8080", client=CifarClient().to_client())
fl.client.start_client(
server_address="127.0.0.1:8080", client=CifarClient().to_client()
)


if __name__ == "__main__":
Expand Down
4 changes: 3 additions & 1 deletion examples/quickstart-huggingface/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def evaluate(self, parameters, config):
return float(loss), len(testloader), {"accuracy": float(accuracy)}

# Start client
fl.client.start_client(server_address="127.0.0.1:8080", client=IMDBClient().to_client())
fl.client.start_client(
server_address="127.0.0.1:8080", client=IMDBClient().to_client()
)


if __name__ == "__main__":
Expand Down
4 changes: 3 additions & 1 deletion examples/quickstart-jax/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,6 @@ def evaluate(


# Start Flower client
fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client())
fl.client.start_client(
server_address="127.0.0.1:8080", client=FlowerClient().to_client()
)
3 changes: 2 additions & 1 deletion examples/quickstart-mlcube/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def main():
)

fl.client.start_client(
server_address="0.0.0.0:8080", client=MLCubeClient(workspace=workspace).to_client()
server_address="0.0.0.0:8080",
client=MLCubeClient(workspace=workspace).to_client(),
)


Expand Down
4 changes: 3 additions & 1 deletion examples/quickstart-tabnet/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,6 @@ def evaluate(self, parameters, config):


# Start Flower client
fl.client.start_client(server_address="127.0.0.1:8080", client=TabNetClient().to_client())
fl.client.start_client(
server_address="127.0.0.1:8080", client=TabNetClient().to_client()
)
4 changes: 3 additions & 1 deletion examples/quickstart-tensorflow/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,6 @@ def evaluate(self, parameters, config):


# Start Flower client
fl.client.start_client(server_address="127.0.0.1:8080", client=CifarClient().to_client())
fl.client.start_client(
server_address="127.0.0.1:8080", client=CifarClient().to_client()
)
4 changes: 3 additions & 1 deletion examples/sklearn-logreg-mnist/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,6 @@ def evaluate(self, parameters, config): # type: ignore
return loss, len(X_test), {"accuracy": accuracy}

# Start Flower client
fl.client.start_client(server_address="0.0.0.0:8080", client=MnistClient().to_client())
fl.client.start_client(
server_address="0.0.0.0:8080", client=MnistClient().to_client()
)

0 comments on commit ac77ead

Please sign in to comment.