Merged PR 1578: Include FedProx aggregation method
Mirian-Hipolito committed Aug 23, 2023
1 parent e8fe10b commit 8bfe085
Showing 17 changed files with 126 additions and 20 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -47,3 +47,8 @@ allowing users to run experiments using a single-GPU worker by instantiating both
and clients on the same device. For more documentation about how to run an experiment
using a single GPU, please refer to the [README](README.md).


### New features

- 🌟 Include FedProx aggregation method

5 changes: 3 additions & 2 deletions NOTICE.txt
@@ -42,5 +42,6 @@ This software includes parts of Fast AutoAugment repository (https://github.com/kakaobrain/fast-autoaugment).
Code from the paper "Fast AutoAugment" (Accepted at NeurIPS 2019). This example is licensed
under the MIT License; you can find a copy of the license at https://github.com/kakaobrain/fast-autoaugment/blob/master/LICENSE



This software includes parts of the NIID-Bench repository (https://github.com/Xtra-Computing/NIID-Bench).
Code from the paper "Federated Learning on Non-IID Data Silos: An Experimental Study". This example is
licensed under the MIT License; you can find a copy of the license at https://github.com/Xtra-Computing/NIID-Bench/blob/main/LICENSE
2 changes: 2 additions & 0 deletions README.md
@@ -188,6 +188,8 @@ This software includes the model implementation of the FedNewsRec repository (ht
For more information about third-party OSS licences, please refer to [NOTICE.txt](NOTICE.txt).

This software includes the Data Augmentation scripts of the Fast AutoAugment repository (https://github.com/kakaobrain/fast-autoaugment) to preprocess the data used in the [semisupervision](experiments/semisupervision/dataloaders/cifar_dataset.py) experiment.

This software includes the FedProx logic implementation from the NIID-Bench repository (https://github.com/Xtra-Computing/NIID-Bench/tree/main) as a federated aggregation method used in the [trainer](core/trainer.py) object.
## Support

You are welcome to open issues on this repository related to bug reports and feature requests.
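For reference, the FedProx method introduced above augments each client's local training loss with a proximal term that penalizes drift from the global model received at the start of the round, while server-side aggregation remains sample-weighted as in FedAvg. A sketch of the local objective in standard FedProx notation (the symbols below are illustrative and not taken from the repository):

% FedProx local objective for client k at round t
%   w     : local model weights being optimized
%   w^t   : global model weights sent to the client for round t
%   F_k   : the client's ordinary training loss
%   \mu   : proximal coefficient (the new `mu` entry in client_config)
\min_{w} \; F_k(w) + \frac{\mu}{2}\,\lVert w - w^{t} \rVert^{2}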
2 changes: 1 addition & 1 deletion configs/hello_world_mlm_bert_json.yaml
@@ -35,7 +35,7 @@ dp_config:
privacy_metrics_config:
apply_metrics: false # If enabled, the rest of the parameters are needed.

- # Select the Federated optimizer to use (e.g. DGA or FedAvg)
+ # Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
strategy: DGA

# Determines all the server-side settings for training and evaluation rounds
5 changes: 3 additions & 2 deletions configs/hello_world_nlg_gru_json.yaml
@@ -34,8 +34,8 @@ privacy_metrics_config:
# type: adamax
# amsgrad: false

- # Select the Federated optimizer to use (e.g. DGA or FedAvg)
- strategy: DGA
+ # Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
+ strategy: FedProx

# Determines all the server-side settings for training and evaluation rounds
server_config:
@@ -119,6 +119,7 @@ server_config:

# Dictates the learning parameters for client-side model updates. Train data is defined inside this config.
  client_config:
+     mu: 0.001 # Used only for FedProx aggregation method
      meta_learning: basic
      stats_on_smooth_grad: true
      ignore_subtask: false
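As a quick illustration of where the new `mu` field sits, here is a minimal, hypothetical sketch of parsing a config fragment shaped like the YAML above with PyYAML; only the `strategy` and `client_config.mu` names and the 0.001 default come from this commit, everything else is made up for the example:

import yaml  # PyYAML

# Illustrative fragment only; the real config files contain many more keys.
raw = """
strategy: FedProx
client_config:
  mu: 0.001   # Used only for FedProx
"""

config = yaml.safe_load(raw)
client_config = config["client_config"]

# Same fallback the client code (next section) uses when mu is absent.
mu = client_config.get("mu", 0.001)
print(config["strategy"], mu)  # -> FedProx 0.001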
10 changes: 6 additions & 4 deletions core/client.py
@@ -260,7 +260,8 @@ def process_round(client_data, server_data, model, data_path, eps=1e-7):
privacy_metrics_config = config.get('privacy_metrics_config', None)
model_path = config["model_path"]

- StrategyClass = select_strategy(config['strategy'])
+ strategy_algo = config['strategy']
+ StrategyClass = select_strategy(strategy_algo)
strategy = StrategyClass('client', config)
print_rank(f'Client successfully instantiated strategy {strategy}', loglevel=logging.DEBUG)
send_dicts = config['server_config'].get('send_dicts', False)
@@ -358,10 +359,11 @@ def process_round(client_data, server_data, model, data_path, eps=1e-7):
  # This is where training actually happens
  algo_payload = None

- if semisupervision_config != None:
+ if strategy_algo == 'FedLabels':
      datasets =[get_dataset(data_path, config, task, mode="train", test_only=False, data_strct=data_strcts[i], user_idx=0) for i in range(3)]
-     algo_payload = {'algo':'FedLabels', 'data': datasets, 'iter': iteration, 'config': semisupervision_config}
-
+     algo_payload = {'strategy':'FedLabels', 'data': datasets, 'iter': iteration, 'config': semisupervision_config}
+ elif strategy_algo == 'FedProx':
+     algo_payload = {'strategy':'FedProx', 'mu': client_config.get('mu',0.001)}
  train_loss, num_samples, algo_computation = trainer.train_desired_samples(desired_max_samples=desired_max_samples, apply_privacy_metrics=apply_privacy_metrics, algo_payload = algo_payload)
  print_rank('client={}: training loss={}'.format(client_id[0], train_loss), loglevel=logging.DEBUG)

8 changes: 7 additions & 1 deletion core/strategies/__init__.py
@@ -7,9 +7,15 @@
from .fedlabels import FedLabels

  def select_strategy(strategy):
+     ''' Selects the aggregation strategy class
+     NOTE: FedProx uses FedAvg weights during aggregation,
+     which are proportional to the number of samples in
+     each client.
+     '''
      if strategy.lower() == 'dga':
          return DGA
-     elif strategy.lower() == 'fedavg':
+     elif strategy.lower() in ['fedavg', 'fedprox']:
          return FedAvg
      elif strategy.lower() == 'fedlabels':
          return FedLabels
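The docstring added above notes that FedProx reuses FedAvg's aggregation weights, which are proportional to each client's sample count; only the client-side loss changes. A minimal sketch of that sample-weighted average, assuming plain NumPy arrays for the client updates (this is not the FedAvg class from core/strategies, which this diff does not show):

import numpy as np

def sample_weighted_average(client_params, client_num_samples):
    """Average client parameter vectors, weighting each client by its
    share of the total number of training samples (FedAvg-style)."""
    total = float(sum(client_num_samples))
    return sum((n / total) * p for p, n in zip(client_params, client_num_samples))

# Example: the second client holds three times as much data, so it gets 3x the weight.
p1, p2 = np.array([1.0, 1.0]), np.array([2.0, 2.0])
print(sample_weighted_average([p1, p2], [100, 300]))  # -> [1.75 1.75]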
91 changes: 90 additions & 1 deletion core/trainer.py
@@ -328,8 +328,10 @@ def train_desired_samples(self, desired_max_samples=None, apply_privacy_metrics=

        if algo_payload == None:
            num_samples_per_epoch, train_loss_per_epoch = self.run_train_epoch(desired_max_samples, apply_privacy_metrics)
-       elif algo_payload['algo'] == 'FedLabels':
+       elif algo_payload['strategy'] == 'FedLabels':
            num_samples_per_epoch, train_loss_per_epoch, algo_computation = self.run_train_epoch_sup(desired_max_samples, apply_privacy_metrics, algo_payload)
+       elif algo_payload['strategy'] == 'FedProx':
+           num_samples_per_epoch, train_loss_per_epoch = self.run_train_epoch_fedprox(desired_max_samples, apply_privacy_metrics, algo_payload)

        num_samples += num_samples_per_epoch
        total_train_loss += train_loss_per_epoch
@@ -411,6 +413,93 @@ def run_train_epoch(self, desired_max_samples=None, apply_privacy_metrics=False)

    return num_samples, sum_train_loss

def run_train_epoch_fedprox(self, desired_max_samples=None, apply_privacy_metrics=False, algo_payload=None):
    """Run one training epoch using the FedProx proximal term.
    Training stops after the desired number of samples has been processed.
    Args:
        desired_max_samples (int): number of samples you would like to process.
        apply_privacy_metrics (bool): whether to save the batches used in the round for privacy-metrics evaluation.
        algo_payload (dict): hyperparameters needed to fine-tune the FedProx algorithm.
    Returns:
        2-tuple of (int, float): number of processed samples and total training loss.
    """

    sum_train_loss = 0.0
    num_samples = 0
    self.reset_gradient_power()

    # Reset gradient just in case
    self.model.zero_grad()

    # FedProx parameters: mu and a frozen copy of the global model received for this round
    mu = algo_payload['mu']
    global_model = to_device(copy.deepcopy(self.model))
    global_weight_collector = list(global_model.parameters())

    train_loader = self.train_dataloader.create_loader()
    for batch in train_loader:
        if desired_max_samples is not None and num_samples >= desired_max_samples:
            break

        # Compute loss
        if self.optimizer is not None:
            self.optimizer.zero_grad()

        if self.ignore_subtask is True:
            loss = self.model.single_task_loss(batch)
        else:
            if apply_privacy_metrics:
                if "x" in batch:
                    indices = to_device(batch["x"])
                elif "input_ids" in batch:
                    indices = to_device(batch["input_ids"])
                self.cached_batches.append(indices)
            loss = self.model.loss(batch)

        # FedProx regularization term: (mu / 2) * ||w - w_global||^2 summed over all parameters
        fed_prox_reg = 0.0
        for param_index, param in enumerate(self.model.parameters()):
            fed_prox_reg += ((mu / 2) * torch.norm((param - global_weight_collector[param_index]))**2)
        loss += fed_prox_reg
        loss.backward()

        # Apply gradient clipping
        if self.max_grad_norm is not None:
            grad_norm = nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        # Sum up the gradient power
        self.estimate_sufficient_stats()

        # Now that the gradients have been scaled, we can apply them
        if self.optimizer is not None:
            self.optimizer.step()

        print_rank("step: {}, loss: {}".format(self.step, loss.item()), loglevel=logging.DEBUG)

        # Post-processing in this loop
        # Sum up the loss
        sum_train_loss += loss.item()

        # Increment the number of samples processed already
        if "attention_mask" in batch:
            num_samples += torch.sum(batch["attention_mask"].detach().cpu() == 1).item()
        elif "total_frames" in batch:
            num_samples += batch["total_frames"]
        else:
            num_samples += len(batch["x"])

        # Update the counters
        self.step += 1

        # Take a step in lr_scheduler
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

    return num_samples, sum_train_loss

def run_train_epoch_sup(self, desired_max_samples=None, apply_privacy_metrics=False, algo_payload=None):
"""Implementation example for training the model using semisupervision.
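As a sanity check on the regularization term in run_train_epoch_fedprox above: the gradient of (mu / 2) * ||w - w_global||^2 with respect to w is mu * (w - w_global), so the proximal term simply pulls every parameter back toward its value in the global model, scaled by mu. A small, self-contained sketch (the tensors are made up for illustration):

import torch

mu = 0.001
w = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  # local parameters being trained
w_global = torch.tensor([0.5, 2.0, 4.0])                # frozen copy of the global model

# Same form as the fed_prox_reg term in the trainer code above.
prox = (mu / 2) * torch.norm(w - w_global) ** 2
prox.backward()

print(w.grad)                         # tensor([ 0.0005,  0.0000, -0.0010])
print(mu * (w - w_global).detach())   # identical, confirming the proximal pull toward w_global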
2 changes: 1 addition & 1 deletion experiments/classif_cnn/config.yaml
@@ -12,7 +12,7 @@ dp_config:
privacy_metrics_config:
apply_metrics: false # cache data to compute additional metrics

- # Select the Federated optimizer to use (e.g. DGA or FedAvg)
+ # Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
strategy: DGA

# Determines all the server-side settings for training and evaluation rounds
2 changes: 1 addition & 1 deletion experiments/cv/config.yaml
@@ -9,7 +9,7 @@ dp_config:
privacy_metrics_config:
apply_metrics: false # cache data to compute additional metrics

- strategy: DGA
+ strategy: DGA # Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)

server_config:
wantRL: false # whether to use RL-based meta-optimizers
2 changes: 1 addition & 1 deletion experiments/cv_cnn_femnist/config.yaml
@@ -12,7 +12,7 @@ dp_config:
privacy_metrics_config:
apply_metrics: false # cache data to compute additional metrics

- # Select the Federated optimizer to use (e.g. DGA or FedAvg)
+ # Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
strategy: FedAvg

# Determines all the server-side settings for training and evaluation rounds
2 changes: 1 addition & 1 deletion experiments/cv_lr_mnist/config.yaml
@@ -13,7 +13,7 @@ dp_config:
privacy_metrics_config:
apply_metrics: false # cache data to compute additional metrics

- # Select the Federated optimizer to use (e.g. DGA or FedAvg)
+ # Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
strategy: FedAvg

# Determines all the server-side settings for training and evaluation rounds
2 changes: 1 addition & 1 deletion experiments/cv_resnet_fedcifar100/config.yaml
@@ -12,7 +12,7 @@ dp_config:
privacy_metrics_config:
apply_metrics: false # cache data to compute additional metrics

- # Select the Federated optimizer to use (e.g. DGA or FedAvg)
+ # Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
strategy: FedAvg

# Determines all the server-side settings for training and evaluation rounds
2 changes: 1 addition & 1 deletion experiments/ecg_cnn/config.yaml
@@ -12,7 +12,7 @@ dp_config:
privacy_metrics_config:
apply_metrics: false # cache data to compute additional metrics

- # Select the Federated optimizer to use (e.g. DGA or FedAvg)
+ # Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
strategy: DGA

# Determines all the server-side settings for training and evaluation rounds
2 changes: 1 addition & 1 deletion experiments/fednewsrec/config.yaml
@@ -11,7 +11,7 @@ dp_config:
privacy_metrics_config:
apply_metrics: false # cache data to compute additional metrics

- # Select the Federated optimizer to use (e.g. DGA or FedAvg)
+ # Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
strategy: FedAvg

# Determines all the server-side settings for training and evaluation rounds
2 changes: 1 addition & 1 deletion experiments/nlp_rnn_fedshakespeare/config.yaml
@@ -12,7 +12,7 @@ dp_config:
privacy_metrics_config:
apply_metrics: false # cache data to compute additional metrics

- # Select the Federated optimizer to use (e.g. DGA or FedAvg)
+ # Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
strategy: FedAvg

# Determines all the server-side settings for training and evaluation rounds
2 changes: 1 addition & 1 deletion experiments/semisupervision/config.yaml
@@ -13,7 +13,7 @@ dp_config:
privacy_metrics_config:
apply_metrics: false # cache data to compute additional metrics

- # Select the Federated optimizer to use (e.g. DGA or FedAvg)
+ # Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
strategy: FedLabels

# Determines all the server-side settings for training and evaluation rounds