|
14 | 14 | import os |
15 | 15 | import time |
16 | 16 | from collections.abc import Mapping, MutableMapping |
17 | | -from typing import Any, cast |
| 17 | +from typing import Any, Callable, cast |
18 | 18 |
|
19 | 19 | import torch |
20 | 20 | import torch.distributed as dist |
@@ -359,6 +359,7 @@ def __init__( |
359 | 359 | eval_workflow_name: str = "train", |
360 | 360 | train_workflow: BundleWorkflow | None = None, |
361 | 361 | eval_workflow: BundleWorkflow | None = None, |
| 362 | + stats_sender: Callable | None = None, |
362 | 363 | ): |
363 | 364 | self.logger = logger |
364 | 365 | self.bundle_root = bundle_root |
@@ -390,6 +391,7 @@ def __init__( |
390 | 391 | if not isinstance(eval_workflow, BundleWorkflow) or eval_workflow.get_workflow_type() is None: |
391 | 392 | raise ValueError("train workflow must be BundleWorkflow and set type.") |
392 | 393 | self.eval_workflow = eval_workflow |
| 394 | + self.stats_sender = stats_sender |
393 | 395 |
|
394 | 396 | self.app_root = "" |
395 | 397 | self.filter_parser: ConfigParser | None = None |
@@ -478,6 +480,12 @@ def initialize(self, extra=None): |
478 | 480 | if len(config_filter_files) > 0: |
479 | 481 | self.filter_parser.read_config(config_filter_files) |
480 | 482 |
|
| 483 | + # set stats sender for nvflare |
| 484 | + self.stats_sender = extra.get(ExtraItems.STATS_SENDER, self.stats_sender) |
| 485 | + if self.stats_sender is not None: |
| 486 | + self.stats_sender.attach(self.trainer) |
| 487 | + self.stats_sender.attach(self.evaluator) |
| 488 | + |
481 | 489 | # Get filters |
482 | 490 | self.pre_filters = self.filter_parser.get_parsed_content( |
483 | 491 | FiltersType.PRE_FILTERS, default=ConfigItem(None, FiltersType.PRE_FILTERS) |
|
0 commit comments