Skip to content

Commit 98b2c94

Browse files
committed
Add stats_sender to MonaiAlgo for FL stats
Signed-off-by: Kevin <kevlu@nvidia.com>
1 parent c22a2bd commit 98b2c94

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

monai/fl/client/monai_algo.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import os
1515
import time
1616
from collections.abc import Mapping, MutableMapping
17-
from typing import Any, cast
17+
from typing import Any, Callable, cast
1818

1919
import torch
2020
import torch.distributed as dist
@@ -359,6 +359,7 @@ def __init__(
359359
eval_workflow_name: str = "train",
360360
train_workflow: BundleWorkflow | None = None,
361361
eval_workflow: BundleWorkflow | None = None,
362+
stats_sender: Callable | None = None,
362363
):
363364
self.logger = logger
364365
self.bundle_root = bundle_root
@@ -390,6 +391,7 @@ def __init__(
390391
if not isinstance(eval_workflow, BundleWorkflow) or eval_workflow.get_workflow_type() is None:
391392
raise ValueError("train workflow must be BundleWorkflow and set type.")
392393
self.eval_workflow = eval_workflow
394+
self.stats_sender = stats_sender
393395

394396
self.app_root = ""
395397
self.filter_parser: ConfigParser | None = None
@@ -478,6 +480,12 @@ def initialize(self, extra=None):
478480
if len(config_filter_files) > 0:
479481
self.filter_parser.read_config(config_filter_files)
480482

483+
# set stats sender for nvflare
484+
self.stats_sender = extra.get(ExtraItems.STATS_SENDER, self.stats_sender)
485+
if self.stats_sender is not None:
486+
self.stats_sender.attach(self.trainer)
487+
self.stats_sender.attach(self.evaluator)
488+
481489
# Get filters
482490
self.pre_filters = self.filter_parser.get_parsed_content(
483491
FiltersType.PRE_FILTERS, default=ConfigItem(None, FiltersType.PRE_FILTERS)

monai/fl/utils/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class ExtraItems(StrEnum):
2929
MODEL_TYPE = "fl_model_type"
3030
CLIENT_NAME = "fl_client_name"
3131
APP_ROOT = "fl_app_root"
32+
STATS_SENDER = "fl_stats_sender"
3233

3334

3435
class FlPhase(StrEnum):

0 commit comments

Comments
 (0)