Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve federated learning rounds per epoch #41

Merged
merged 2 commits into from
May 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 43 additions & 35 deletions fltk/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,45 +57,53 @@ def _event_loop(self):
time.sleep(0.1)
self.logger.info('Exiting node')

def train(self, num_epochs: int):
def train(self, num_epochs: int, round_id: int):
"""
Function implementing federated learning training loop.
@param num_epochs: Number of epochs to run.
Function implementing federated learning training loop, allowing to run for a configurable number of epochs
on a local dataset. Note that only the last statistics of a run are sent to the caller (i.e. Federator).
@param num_epochs: Number of epochs to run during a communication round's training loop.
@type num_epochs: int
@return: Final running loss statistic and acquired parameters of the locally trained network.
@param round_id: Global communication round ID to be used during training.
@type round_id: int
@return: Final running loss statistic and acquired parameters of the locally trained network. NOTE that
intermediate information is only logged to the STD-out.
@rtype: Tuple[float, Dict[str, torch.Tensor]]
"""
start_time = time.time()

running_loss = 0.0
final_running_loss = 0.0
if self.distributed:
self.dataset.train_sampler.set_epoch(num_epochs)

number_of_training_samples = len(self.dataset.get_train_loader())
self.logger.info(f'{self.id}: Number of training samples: {number_of_training_samples}')

for i, (inputs, labels) in enumerate(self.dataset.get_train_loader(), 0):
inputs, labels = inputs.to(self.device), labels.to(self.device)

# zero the parameter gradients
self.optimizer.zero_grad()

outputs = self.net(inputs)
loss = self.loss_function(outputs, labels)

loss.backward()
self.optimizer.step()
running_loss += loss.item()
# Mark logging update step
if i % self.config.log_interval == 0:
self.logger.info(
f'[{self.id}] [{num_epochs:d}, {i:5d}] loss: {running_loss / self.config.log_interval:.3f}')
final_running_loss = running_loss / self.config.log_interval
running_loss = 0.0
end_time = time.time()
duration = end_time - start_time
self.logger.info(f'Train duration is {duration} seconds')
for local_epoch in range(num_epochs):
effective_epoch = round_id * num_epochs + local_epoch
JMGaljaard marked this conversation as resolved.
Show resolved Hide resolved
progress = f'[RD-{round_id}][LE-{local_epoch}][EE-{effective_epoch}]'
if self.distributed:
# In case a client occurs within (num_epochs) communication rounds as this would cause
# an order or data to re-occur during training.
self.dataset.train_sampler.set_epoch(effective_epoch)

training_cardinality = len(self.dataset.get_train_loader())
self.logger.info(f'{progress}{self.id}: Number of training samples: {training_cardinality}')

for i, (inputs, labels) in enumerate(self.dataset.get_train_loader(), 0):
inputs, labels = inputs.to(self.device), labels.to(self.device)

# zero the parameter gradients
self.optimizer.zero_grad()

outputs = self.net(inputs)
loss = self.loss_function(outputs, labels)

loss.backward()
self.optimizer.step()
running_loss += loss.item()
# Mark logging update step
if i % self.config.log_interval == 0:
self.logger.info(
f'[{self.id}] [{local_epoch}/{num_epochs:d}, {i:5d}] loss: {running_loss / self.config.log_interval:.3f}')
final_running_loss = running_loss / self.config.log_interval
running_loss = 0.0
end_time = time.time()
duration = end_time - start_time
self.logger.info(f'{progress} Train duration is {duration} seconds')

return final_running_loss, self.get_nn_parameters(),

Expand Down Expand Up @@ -148,7 +156,7 @@ def test(self) -> Tuple[float, float, np.array]:
def get_client_datasize(self): # pylint: disable=missing-function-docstring
return len(self.dataset.get_train_sampler())

def exec_round(self, num_epochs: int) -> Tuple[Any, Any, Any, Any, float, float, float, np.array]:
def exec_round(self, num_epochs: int, round_id: int) -> Tuple[Any, Any, Any, Any, float, float, float, np.array]:
"""
Function as access point for the Federator Node to kick off a remote learning round on a client.
@param num_epochs: Number of epochs to run
Expand All @@ -157,9 +165,9 @@ def exec_round(self, num_epochs: int) -> Tuple[Any, Any, Any, Any, float, float,
training make-span, testing make-span, and confusion matrix.
@rtype: Tuple[Any, Any, Any, Any, float, float, float, np.array]
"""
self.logger.info(f"[EXEC] running {num_epochs} locally...")
start = time.time()

loss, weights = self.train(num_epochs)
loss, weights = self.train(num_epochs, round_id)
time_mark_between = time.time()
accuracy, test_loss, test_conf_matrix = self.test()

Expand Down
2 changes: 1 addition & 1 deletion fltk/core/federator.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def training_cb(fut: torch.Future, client_ref: LocalClient, client_weights, clie
client_ref.exp_data.append(c_record)

for client in selected_clients:
future = self.message_async(client.ref, Client.exec_round, num_epochs)
future = self.message_async(client.ref, Client.exec_round, num_epochs, com_round_id)
cb_factory(future, training_cb, client, client_weights, client_sizes, num_epochs)
self.logger.info(f'Request sent to client {client.name}')
training_futures.append(future)
Expand Down
10 changes: 9 additions & 1 deletion fltk/util/task/config/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# noinspection PyUnresolvedReferences
from typing import List, Optional, OrderedDict, Any, Union, Tuple, Type, Dict, MutableMapping, T

import deprecate
from dataclasses_json import dataclass_json, LetterCase, config
# noinspection PyProtectedMember
from torch.nn.modules.loss import _Loss
Expand Down Expand Up @@ -224,14 +225,21 @@ class LearningParameters:
Dataclass containing configuration parameters for the learning process itself. This includes the Federated learning
parameters as well as some system parameters like cuda.
"""
total_epochs: int
_total_epochs: int = field(metadata=config(field_name='totalEpochs'))
cuda: bool
rounds: Optional[int] = None
epochs_per_round: Optional[int] = None
clients_per_round: Optional[int] = None
aggregation: Optional[Aggregations] = None
data_sampler: Optional[SamplerConfiguration] = None

@property
def total_epochs(self):
logging.warning('By default `total_epochs` is not used duruing Federated Learning. This attribute will be'
'changed in a coming release.')
return self._total_epochs


@dataclass_json(letter_case=LetterCase.CAMEL)
@dataclass(frozen=True)
class ExperimentConfiguration:
Expand Down
2 changes: 1 addition & 1 deletion tests/core/client_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,5 @@ def test_fed_client(self, name, net: Nets, dataset: Dataset):
self.assertTrue(fed_client.is_ready())

with patch.object(DS, 'get_train_dataset', fed_client.dataset):
self.assertTrue(fed_client.exec_round(1))
self.assertTrue(fed_client.exec_round(1, 0))