Example of using NVIDIA FLARE to train an image classifier using federated averaging (FedAvg) and PyTorch as the deep learning training framework. This example also highlights the TensorBoard streaming capability from the clients to the server.
NOTE: This example uses the CIFAR-10 dataset and will load its data within the trainer code.
Follow the Installation instructions. Install additional requirements:
pip3 install torch torchvision tensorboard
Follow the Quickstart instructions to set up your POC ("proof of concept") workspace.
Log into the Admin client by entering admin for both the username and password.
Then, use this Admin command to run the experiment:
submit_job hello-pt-tb
On the client side, the AnalyticsSender works as a TensorBoard SummaryWriter. Instead of writing to TB files, it generates NVFLARE events of type analytix_log_stats.
The ConvertToFedEvent widget turns the analytix_log_stats event into a fed event, fed.analytix_log_stats, which is delivered to the server side.
On the server side, the TBAnalyticsReceiver is configured to process fed.analytix_log_stats events. It writes the received TB data into the appropriate TB files on the server (by default under server/[run number]/tb_events).
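The event flow described above (client-side sender, conversion to a fed event, server-side receiver) can be illustrated with a small toy model. This is only a sketch of the idea in plain Python and does not use the NVFLARE API: ToyEventBus, ToySender, ToyConverter, ToyReceiver and run_demo are hypothetical names invented for illustration. Only the event type analytix_log_stats and the fed. prefix come from the text.

```python
from collections import defaultdict
from typing import Callable, Dict, List

LOCAL_EVENT = "analytix_log_stats"  # event type named in the text
FED_PREFIX = "fed."                 # prefix seen in fed.analytix_log_stats


class ToyEventBus:
    """Hypothetical in-process event bus; a stand-in for the FLARE runtime."""

    def __init__(self) -> None:
        self._handlers: Dict[str, List[Callable[[dict], None]]] = defaultdict(list)

    def subscribe(self, event_type: str, handler: Callable[[dict], None]) -> None:
        self._handlers[event_type].append(handler)

    def fire(self, event_type: str, payload: dict) -> None:
        # Deliver the payload to every handler registered for this event type.
        for handler in list(self._handlers[event_type]):
            handler(payload)


class ToySender:
    """Offers a SummaryWriter-like add_scalar, but fires an event instead of writing a file."""

    def __init__(self, bus: ToyEventBus) -> None:
        self._bus = bus

    def add_scalar(self, tag: str, value: float, step: int) -> None:
        self._bus.fire(LOCAL_EVENT, {"tag": tag, "value": value, "step": step})


class ToyConverter:
    """Re-fires selected client-local events on the server bus with the fed. prefix."""

    def __init__(self, client_bus: ToyEventBus, server_bus: ToyEventBus,
                 events_to_convert: List[str]) -> None:
        for event_type in events_to_convert:
            client_bus.subscribe(
                event_type,
                # Bind event_type as a default argument to avoid late binding.
                lambda payload, et=event_type: server_bus.fire(FED_PREFIX + et, payload),
            )


class ToyReceiver:
    """Collects fed events in memory; a real receiver would write TB files instead."""

    def __init__(self, server_bus: ToyEventBus) -> None:
        self.records: List[dict] = []
        server_bus.subscribe(FED_PREFIX + LOCAL_EVENT, self.records.append)


def run_demo() -> List[dict]:
    client_bus, server_bus = ToyEventBus(), ToyEventBus()
    ToyConverter(client_bus, server_bus, [LOCAL_EVENT])
    receiver = ToyReceiver(server_bus)
    sender = ToySender(client_bus)
    # Made-up loss values, purely to exercise the pipeline.
    for step, loss in enumerate([0.9, 0.7, 0.5]):
        sender.add_scalar("train_loss", loss, step)
    return receiver.records


if __name__ == "__main__":
    for record in run_demo():
        print(record)
```

The point of the design is that the training code only ever calls a writer-style method. Whether the data ends up in a local file or travels to the server is decided by which components are configured around it.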
To view training metrics that are being streamed to the server, run:
tensorboard --logdir=poc/server/[run number]/tb_events
Note: if the server is running on a remote machine, use port forwarding to view the TensorBoard dashboard in a browser. For example:
ssh -L {local_machine_port}:127.0.0.1:6006 user@server_ip
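If you script these steps, the two commands above can be assembled from their parts. The following is a minimal sketch using only the Python standard library. The helper names tensorboard_command and ssh_forward_command are hypothetical, the run directory name is passed in as-is because its exact format depends on your workspace, and port 6006 is assumed to be the port TensorBoard listens on (its default).

```python
import shlex
from pathlib import PurePosixPath
from typing import List


def tensorboard_command(workspace: str, run_dir: str) -> List[str]:
    """Build the tensorboard invocation for a given workspace and run directory."""
    logdir = PurePosixPath(workspace) / "server" / run_dir / "tb_events"
    return ["tensorboard", f"--logdir={logdir}"]


def ssh_forward_command(local_port: int, user: str, server_ip: str,
                        remote_port: int = 6006) -> List[str]:
    """Build an ssh command that forwards local_port to TensorBoard on the server."""
    return ["ssh", "-L", f"{local_port}:127.0.0.1:{remote_port}", f"{user}@{server_ip}"]


if __name__ == "__main__":
    # Placeholder values: substitute your own run directory, user, and host.
    print(shlex.join(tensorboard_command("poc", "RUN_DIR")))
    print(shlex.join(ssh_forward_command(16006, "user", "server_ip")))
```

With the tunnel open, the dashboard would be reachable in a local browser at http://127.0.0.1:{local_machine_port}.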
To shut down the clients and server, run the following Admin commands:
shutdown client
shutdown server
NOTE: For a more in-depth guide about the TensorBoard streaming feature, see Quickstart (PyTorch with TensorBoard).