From 8bfe0854ab293c6226df66856b3d96b39dbe61fe Mon Sep 17 00:00:00 2001
From: Mirian Hipolito Garcia
Date: Wed, 23 Aug 2023 15:36:25 +0000
Subject: [PATCH] Merged PR 1578: Include FedProx aggregation method

Implementation of the FedProx aggregation method, taken from the paper
"Federated Learning on Non-IID Data Silos: An Experimental Study"
(https://arxiv.org/pdf/2102.02079.pdf).

[x] nlg_gru_fedprox: https://ml.azure.com/runs/8c052875-d053-4e70-b5b6-8f591faf5936?wsid=/subscriptions/d4404794-ab5b-48de-b7c7-ec1fefb0a04e/resourcegroups/gcr-singularity-octo/workspaces/msroctows&tid=72f988bf-86f1-41af-91ab-2d7cd011db47

**Comparison**

- DGA (Acc 0.15, Loss 5.5)

![image.png](https://msktg.visualstudio.com/c507252c-d1be-4d67-a4a1-03b0181c35c7/_apis/git/repositories/0392018c-4507-44bf-97e2-f2bb75d454f1/pullRequests/1578/attachments/image.png)

- FedProx (Acc 0.18, Loss 4.8)

![image (2).png](https://msktg.visualstudio.com/c507252c-d1be-4d67-a4a1-03b0181c35c7/_apis/git/repositories/0392018c-4507-44bf-97e2-f2bb75d454f1/pullRequests/1578/attachments/image%20%282%29.png)
---
 CHANGELOG.md                                  |  5 +
 NOTICE.txt                                    |  5 +-
 README.md                                     |  2 +
 configs/hello_world_mlm_bert_json.yaml        |  2 +-
 configs/hello_world_nlg_gru_json.yaml         |  5 +-
 core/client.py                                | 10 +-
 core/strategies/__init__.py                   |  8 +-
 core/trainer.py                               | 91 ++++++++++++++++++-
 experiments/classif_cnn/config.yaml           |  2 +-
 experiments/cv/config.yaml                    |  2 +-
 experiments/cv_cnn_femnist/config.yaml        |  2 +-
 experiments/cv_lr_mnist/config.yaml           |  2 +-
 experiments/cv_resnet_fedcifar100/config.yaml |  2 +-
 experiments/ecg_cnn/config.yaml               |  2 +-
 experiments/fednewsrec/config.yaml            |  2 +-
 .../nlp_rnn_fedshakespeare/config.yaml        |  2 +-
 experiments/semisupervision/config.yaml       |  2 +-
 17 files changed, 126 insertions(+), 20 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5d09e64..a37f332 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -47,3 +47,8 @@ allowing users to run experiments using a single-GPU worker by instantiating bot
 and clients on the same device. For more documentation about how to run an
 experiments using a single GPU, please refer to the [README](README.md).
 
+
+### New features
+
+- 🌟 Include FedProx aggregation method
+
diff --git a/NOTICE.txt b/NOTICE.txt
index 928b79a..4ea60a1 100644
--- a/NOTICE.txt
+++ b/NOTICE.txt
@@ -42,5 +42,6 @@ This software includes parts of Fast AutoAugment repository (https://github.com/
 Code from the paper "Fast AutoAugment" (Accepted at NeurIPS 2019). This example is
 licenced under MIT License, you can find a copy of this licence at
 https://github.com/kakaobrain/fast-autoaugment/blob/master/LICENSE
-
-
+This software includes parts of NIID-Bench repository (https://github.com/Xtra-Computing/NIID-Bench).
+Code from the paper "Federated Learning on Non-IID Data Silos: An Experimental Study". This example is
+licenced under MIT License, you can find a copy of this licence at https://github.com/Xtra-Computing/NIID-Bench/blob/main/LICENSE
diff --git a/README.md b/README.md
index 0e0967c..cb4a4a2 100644
--- a/README.md
+++ b/README.md
@@ -188,6 +188,8 @@ This software includes the model implementation of the FedNewsRec repository (ht
 For more information about third-party OSS licence, please refer to [NOTICE.txt](NOTICE.txt).
 
 This software includes the Data Augmentation scripts of the Fast AutoAugment repository (https://github.com/kakaobrain/fast-autoaugment) to preprocess the data used in the [semisupervision](experiments/semisupervision/dataloaders/cifar_dataset.py) experiment.
+
+This software includes the FedProx logic implementation of the NIID-Bench repository (https://github.com/Xtra-Computing/NIID-Bench/tree/main) as a federated aggregation method used in the [trainer](core/trainer.py) object.
 
 ## Support
 
 You are welcome to open issues on this repository related to bug reports and feature requests.
diff --git a/configs/hello_world_mlm_bert_json.yaml b/configs/hello_world_mlm_bert_json.yaml
index 8787df5..cef0dd3 100644
--- a/configs/hello_world_mlm_bert_json.yaml
+++ b/configs/hello_world_mlm_bert_json.yaml
@@ -35,7 +35,7 @@ dp_config:
 privacy_metrics_config:
     apply_metrics: false # If enabled, the rest of parameters is needed.
 
-# Select the Federated optimizer to use (e.g. DGA or FedAvg)
+# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
 strategy: DGA
 
 # Determines all the server-side settings for training and evaluation rounds
diff --git a/configs/hello_world_nlg_gru_json.yaml b/configs/hello_world_nlg_gru_json.yaml
index 89a6f22..3656f0c 100644
--- a/configs/hello_world_nlg_gru_json.yaml
+++ b/configs/hello_world_nlg_gru_json.yaml
@@ -34,8 +34,8 @@ privacy_metrics_config:
 #   type: adamax
 #   amsgrad: false
 
-# Select the Federated optimizer to use (e.g. DGA or FedAvg)
-strategy: DGA
+# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
+strategy: FedProx
 
 # Determines all the server-side settings for training and evaluation rounds
 server_config:
@@ -119,6 +119,7 @@ server_config:
 
 # Dictates the learning parameters for client-side model updates. Train data is defined inside this config.
 client_config:
+    mu: 0.001                       # Used only for FedProx aggregation method
    meta_learning: basic
    stats_on_smooth_grad: true
    ignore_subtask: false
diff --git a/core/client.py b/core/client.py
index 9026c11..5fbadd6 100644
--- a/core/client.py
+++ b/core/client.py
@@ -260,7 +260,8 @@ def process_round(client_data, server_data, model, data_path, eps=1e-7):
     privacy_metrics_config = config.get('privacy_metrics_config', None)
     model_path = config["model_path"]
 
-    StrategyClass = select_strategy(config['strategy'])
+    strategy_algo = config['strategy']
+    StrategyClass = select_strategy(strategy_algo)
     strategy = StrategyClass('client', config)
     print_rank(f'Client successfully instantiated strategy {strategy}', loglevel=logging.DEBUG)
     send_dicts = config['server_config'].get('send_dicts', False)
@@ -358,10 +359,11 @@ def process_round(client_data, server_data, model, data_path, eps=1e-7):
 
     # This is where training actually happens
     algo_payload = None
-    if semisupervision_config != None:
+    if strategy_algo == 'FedLabels':
         datasets =[get_dataset(data_path, config, task, mode="train", test_only=False, data_strct=data_strcts[i], user_idx=0) for i in range(3)]
-        algo_payload = {'algo':'FedLabels', 'data': datasets, 'iter': iteration, 'config': semisupervision_config}
-
+        algo_payload = {'strategy':'FedLabels', 'data': datasets, 'iter': iteration, 'config': semisupervision_config}
+    elif strategy_algo == 'FedProx':
+        algo_payload = {'strategy':'FedProx', 'mu': client_config.get('mu',0.001)}
 
     train_loss, num_samples, algo_computation = trainer.train_desired_samples(desired_max_samples=desired_max_samples, apply_privacy_metrics=apply_privacy_metrics, algo_payload = algo_payload)
     print_rank('client={}: training loss={}'.format(client_id[0], train_loss), loglevel=logging.DEBUG)
diff --git a/core/strategies/__init__.py b/core/strategies/__init__.py
index e56f1f0..07dd392 100644
--- a/core/strategies/__init__.py
+++ b/core/strategies/__init__.py
@@ -7,9 +7,15 @@ from .fedlabels import FedLabels
 
 
 def select_strategy(strategy):
+    ''' Selects the aggregation strategy class.
+
+    NOTE: FedProx uses FedAvg weights during aggregation,
+    which are proportional to the number of samples in
+    each client.
+    '''
     if strategy.lower() == 'dga':
         return DGA
-    elif strategy.lower() == 'fedavg':
+    elif strategy.lower() in ['fedavg', 'fedprox']:
         return FedAvg
     elif strategy.lower() == 'fedlabels':
         return FedLabels
diff --git a/core/trainer.py b/core/trainer.py
index 5b38eea..facb909 100644
--- a/core/trainer.py
+++ b/core/trainer.py
@@ -328,8 +328,10 @@ def train_desired_samples(self, desired_max_samples=None, apply_privacy_metrics=
 
         if algo_payload == None:
             num_samples_per_epoch, train_loss_per_epoch = self.run_train_epoch(desired_max_samples, apply_privacy_metrics)
-        elif algo_payload['algo'] == 'FedLabels':
+        elif algo_payload['strategy'] == 'FedLabels':
             num_samples_per_epoch, train_loss_per_epoch, algo_computation = self.run_train_epoch_sup(desired_max_samples, apply_privacy_metrics, algo_payload)
+        elif algo_payload['strategy'] == 'FedProx':
+            num_samples_per_epoch, train_loss_per_epoch = self.run_train_epoch_fedprox(desired_max_samples, apply_privacy_metrics, algo_payload)
 
         num_samples += num_samples_per_epoch
         total_train_loss += train_loss_per_epoch
@@ -411,6 +413,93 @@ def run_train_epoch(self, desired_max_samples=None, apply_privacy_metrics=False)
 
         return num_samples, sum_train_loss
 
+    def run_train_epoch_fedprox(self, desired_max_samples=None, apply_privacy_metrics=False, algo_payload=None):
+        """Runs one training epoch, adding the FedProx proximal term to the local loss.
+
+        The training process should stop after the desired number of samples is processed.
+
+        Args:
+            desired_max_samples (int): number of samples that you would like to process.
+            apply_privacy_metrics (bool): whether to save the batches used for the round for privacy metrics evaluation.
+            algo_payload (dict): hyperparameters needed by the FedProx algorithm (the proximal coefficient mu).
+
+        Returns:
+            2-tuple of (int, float): number of processed samples and total training loss.
+        """
+
+        sum_train_loss = 0.0
+        num_samples = 0
+        self.reset_gradient_power()
+
+        # Reset gradient just in case
+        self.model.zero_grad()
+
+        # FedProx parameters
+        mu = algo_payload['mu']
+        global_model = to_device(copy.deepcopy(self.model))
+        global_weight_collector = list(global_model.parameters())
+
+        train_loader = self.train_dataloader.create_loader()
+        for batch in train_loader:
+            if desired_max_samples is not None and num_samples >= desired_max_samples:
+                break
+
+            # Compute loss
+            if self.optimizer is not None:
+                self.optimizer.zero_grad()
+
+            if self.ignore_subtask is True:
+                loss = self.model.single_task_loss(batch)
+            else:
+                if apply_privacy_metrics:
+                    if "x" in batch:
+                        indices = to_device(batch["x"])
+                    elif "input_ids" in batch:
+                        indices = to_device(batch["input_ids"])
+                    self.cached_batches.append(indices)
+                loss = self.model.loss(batch)
+
+            # FedProx regularization term
+            fed_prox_reg = 0.0
+            for param_index, param in enumerate(self.model.parameters()):
+                fed_prox_reg += ((mu / 2) * torch.norm((param - global_weight_collector[param_index]))**2)
+            loss += fed_prox_reg
+            loss.backward()
+
+            # Apply gradient clipping
+            if self.max_grad_norm is not None:
+                grad_norm = nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
+
+            # Sum up the gradient power
+            self.estimate_sufficient_stats()
+
+            # Now that the gradients have been scaled, we can apply them
+            if self.optimizer is not None:
+                self.optimizer.step()
+
+            print_rank("step: {}, loss: {}".format(self.step, loss.item()), loglevel=logging.DEBUG)
+
+            # Post-processing in this loop
+            # Sum up the loss
+            sum_train_loss += loss.item()
+
+            # Increment the number of frames processed already
+            if "attention_mask" in batch:
+                num_samples += torch.sum(batch["attention_mask"].detach().cpu() == 1).item()
+            elif "total_frames" in batch:
+                num_samples += batch["total_frames"]
+            else:
+                num_samples += len(batch["x"])
+
+            # Update the counters
+            self.step += 1
+
+            # Take a step in lr_scheduler
+            if self.lr_scheduler is not None:
+                self.lr_scheduler.step()
+
+        return num_samples, sum_train_loss
+
     def run_train_epoch_sup(self, desired_max_samples=None, apply_privacy_metrics=False, algo_payload=None):
         """Implementation example for training the model using semisupervision.
 
diff --git a/experiments/classif_cnn/config.yaml b/experiments/classif_cnn/config.yaml
index 3e5f5d6..5fbb7bd 100644
--- a/experiments/classif_cnn/config.yaml
+++ b/experiments/classif_cnn/config.yaml
@@ -12,7 +12,7 @@ dp_config:
 privacy_metrics_config:
     apply_metrics: false # cache data to compute additional metrics
 
-# Select the Federated optimizer to use (e.g. DGA or FedAvg)
+# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
 strategy: DGA
 
 # Determines all the server-side settings for training and evaluation rounds
diff --git a/experiments/cv/config.yaml b/experiments/cv/config.yaml
index 4469bb0..7e07822 100644
--- a/experiments/cv/config.yaml
+++ b/experiments/cv/config.yaml
@@ -9,7 +9,7 @@ dp_config:
 privacy_metrics_config:
     apply_metrics: false # cache data to compute additional metrics
 
-strategy: DGA
+strategy: DGA # Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
 
 server_config:
     wantRL: false # whether to use RL-based meta-optimizers
diff --git a/experiments/cv_cnn_femnist/config.yaml b/experiments/cv_cnn_femnist/config.yaml
index 2fe777b..b9ec6d8 100644
--- a/experiments/cv_cnn_femnist/config.yaml
+++ b/experiments/cv_cnn_femnist/config.yaml
@@ -12,7 +12,7 @@ dp_config:
 privacy_metrics_config:
     apply_metrics: false # cache data to compute additional metrics
 
-# Select the Federated optimizer to use (e.g. DGA or FedAvg)
+# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
 strategy: FedAvg
 
 # Determines all the server-side settings for training and evaluation rounds
diff --git a/experiments/cv_lr_mnist/config.yaml b/experiments/cv_lr_mnist/config.yaml
index 90530fd..150ecea 100644
--- a/experiments/cv_lr_mnist/config.yaml
+++ b/experiments/cv_lr_mnist/config.yaml
@@ -13,7 +13,7 @@ dp_config:
 privacy_metrics_config:
     apply_metrics: false # cache data to compute additional metrics
 
-# Select the Federated optimizer to use (e.g. DGA or FedAvg)
+# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
 strategy: FedAvg
 
 # Determines all the server-side settings for training and evaluation rounds
diff --git a/experiments/cv_resnet_fedcifar100/config.yaml b/experiments/cv_resnet_fedcifar100/config.yaml
index 0ffb80b..1b3a43a 100644
--- a/experiments/cv_resnet_fedcifar100/config.yaml
+++ b/experiments/cv_resnet_fedcifar100/config.yaml
@@ -12,7 +12,7 @@ dp_config:
 privacy_metrics_config:
     apply_metrics: false # cache data to compute additional metrics
 
-# Select the Federated optimizer to use (e.g. DGA or FedAvg)
+# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
 strategy: FedAvg
 
 # Determines all the server-side settings for training and evaluation rounds
diff --git a/experiments/ecg_cnn/config.yaml b/experiments/ecg_cnn/config.yaml
index a57cb09..9003add 100644
--- a/experiments/ecg_cnn/config.yaml
+++ b/experiments/ecg_cnn/config.yaml
@@ -12,7 +12,7 @@ dp_config:
 privacy_metrics_config:
     apply_metrics: false # cache data to compute additional metrics
 
-# Select the Federated optimizer to use (e.g. DGA or FedAvg)
+# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
 strategy: DGA
 
 # Determines all the server-side settings for training and evaluation rounds
diff --git a/experiments/fednewsrec/config.yaml b/experiments/fednewsrec/config.yaml
index 386f8aa..f5c7e57 100644
--- a/experiments/fednewsrec/config.yaml
+++ b/experiments/fednewsrec/config.yaml
@@ -11,7 +11,7 @@ dp_config:
 privacy_metrics_config:
     apply_metrics: false # cache data to compute additional metrics
 
-# Select the Federated optimizer to use (e.g. DGA or FedAvg)
+# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
 strategy: FedAvg
 
 # Determines all the server-side settings for training and evaluation rounds
diff --git a/experiments/nlp_rnn_fedshakespeare/config.yaml b/experiments/nlp_rnn_fedshakespeare/config.yaml
index 9616b22..dfaf453 100644
--- a/experiments/nlp_rnn_fedshakespeare/config.yaml
+++ b/experiments/nlp_rnn_fedshakespeare/config.yaml
@@ -12,7 +12,7 @@ dp_config:
 privacy_metrics_config:
     apply_metrics: false # cache data to compute additional metrics
 
-# Select the Federated optimizer to use (e.g. DGA or FedAvg)
+# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
 strategy: FedAvg
 
 # Determines all the server-side settings for training and evaluation rounds
diff --git a/experiments/semisupervision/config.yaml b/experiments/semisupervision/config.yaml
index 3acd1b0..aa1da57 100644
--- a/experiments/semisupervision/config.yaml
+++ b/experiments/semisupervision/config.yaml
@@ -13,7 +13,7 @@ dp_config:
 privacy_metrics_config:
     apply_metrics: false # cache data to compute additional metrics
 
-# Select the Federated optimizer to use (e.g. DGA or FedAvg)
+# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
 strategy: FedLabels
 
 # Determines all the server-side settings for training and evaluation rounds
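
For readers skimming the patch, the core of the change is the proximal term that `run_train_epoch_fedprox` adds to each client's local loss. Below is a minimal, self-contained sketch of that idea only; it is not part of the patch, and `model`, `loss_fn`, `data_loader` and `optimizer` are placeholder names rather than FLUTE's actual Trainer API. Only `mu` maps to the `client_config.mu` value introduced in `configs/hello_world_nlg_gru_json.yaml`.

```python
import torch

def fedprox_local_epoch(model, loss_fn, data_loader, optimizer, mu=0.001):
    """Illustrative FedProx local epoch: task loss + (mu / 2) * ||w - w_global||^2."""
    # Freeze a copy of the global (server) weights before any local update.
    global_params = [p.detach().clone() for p in model.parameters()]

    total_loss = 0.0
    for batch in data_loader:
        optimizer.zero_grad()
        loss = loss_fn(model, batch)  # ordinary task loss on this client's data

        # Proximal term pulls the local weights back towards the global ones.
        prox = sum((mu / 2.0) * torch.norm(p - g) ** 2
                   for p, g in zip(model.parameters(), global_params))

        (loss + prox).backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss
```

As the patched `select_strategy` docstring notes, the server-side aggregation is unchanged: FedProx reuses FedAvg's sample-count weighting, and the proximal term only alters the client-side objective, with larger `mu` values keeping non-IID clients closer to the global model.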