From 3c252a819c8c785fb660c8208948b621b9aad0b9 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 29 Oct 2024 11:27:48 +0800 Subject: [PATCH] Fix the logging of a nested dictionary metric in MLflow (#8169) Fix https://github.com/Project-MONAI/model-zoo/issues/697 ### Description Flatten the metric dict when the metric is a nested dictionary. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- monai/handlers/mlflow_handler.py | 6 ++++-- monai/handlers/stats_handler.py | 5 ++--- monai/networks/utils.py | 13 ++++++------- monai/utils/__init__.py | 1 + monai/utils/misc.py | 13 +++++++++++++ tests/test_handler_mlflow.py | 9 ++++++++- 6 files changed, 34 insertions(+), 13 deletions(-) diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py index c7e293ea7d..3078d89f97 100644 --- a/monai/handlers/mlflow_handler.py +++ b/monai/handlers/mlflow_handler.py @@ -22,7 +22,7 @@ from torch.utils.data import Dataset from monai.apps.utils import get_logger -from monai.utils import CommonKeys, IgniteInfo, ensure_tuple, min_version, optional_import +from monai.utils import CommonKeys, IgniteInfo, ensure_tuple, flatten_dict, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") mlflow, _ = optional_import("mlflow", descriptor="Please install mlflow before using MLFlowHandler.") @@ -303,7 +303,9 @@ def _log_metrics(self, metrics: dict[str, Any], step: int | None = None) -> None run_id = self.cur_run.info.run_id timestamp = int(time.time() * 1000) - metrics_arr = [mlflow.entities.Metric(key, value, timestamp, step or 0) for key, value in metrics.items()] + metrics_arr = [ + mlflow.entities.Metric(key, value, timestamp, step or 0) for key, value in flatten_dict(metrics).items() + ] self.client.log_batch(run_id=run_id, metrics=metrics_arr, params=[], tags=[]) def _parse_artifacts(self): diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index c4971e9cac..214872fef4 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -19,7 +19,7 @@ import torch from monai.apps import get_logger -from monai.utils import IgniteInfo, is_scalar, min_version, optional_import +from monai.utils import IgniteInfo, flatten_dict, is_scalar, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") if TYPE_CHECKING: @@ -211,8 +211,7 @@ def _default_epoch_print(self, engine: Engine) -> None: """ current_epoch = self.global_epoch_transform(engine.state.epoch) - - prints_dict = engine.state.metrics + prints_dict = flatten_dict(engine.state.metrics) if prints_dict is not None and len(prints_dict) > 0: out_str = f"Epoch[{current_epoch}] Metrics -- " for name in sorted(prints_dict): diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 295a055390..cfad0364c3 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -16,6 +16,7 @@ import io import re +import tempfile import warnings from collections import OrderedDict from collections.abc import Callable, Mapping, Sequence @@ -688,16 +689,17 @@ def convert_to_onnx( onnx_inputs = (inputs,) else: onnx_inputs = tuple(inputs) - + temp_file = None if filename is None: - f = io.BytesIO() + temp_file = tempfile.NamedTemporaryFile() + f = temp_file.name else: f = filename torch.onnx.export( mode_to_export, onnx_inputs, - f=f, # type: ignore[arg-type] + f=f, input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, @@ -705,10 +707,7 @@ def convert_to_onnx( do_constant_folding=do_constant_folding, **torch_versioned_kwargs, ) - if filename is None: - onnx_model = onnx.load_model_from_string(f.getvalue()) - else: - onnx_model = onnx.load(filename) + onnx_model = onnx.load(f) if do_constant_folding and polygraphy_imported: from polygraphy.backend.onnx.loader import fold_constants diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 916c1a6c70..79dc1f2304 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -78,6 +78,7 @@ ensure_tuple_size, fall_back_tuple, first, + flatten_dict, get_seed, has_option, is_immutable, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index ec9b1256a2..b96a48ad7e 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -916,3 +916,16 @@ def unsqueeze_right(arr: NT, ndim: int) -> NT: def unsqueeze_left(arr: NT, ndim: int) -> NT: """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(None,) * (ndim - arr.ndim)] + + +def flatten_dict(metrics: dict[str, Any]) -> dict[str, Any]: + """ + Flatten the nested dictionary to a flat dictionary. + """ + result = {} + for key, value in metrics.items(): + if isinstance(value, dict): + result.update(flatten_dict(value)) + else: + result[key] = value + return result diff --git a/tests/test_handler_mlflow.py b/tests/test_handler_mlflow.py index 44adc49fc2..36d59ff1bf 100644 --- a/tests/test_handler_mlflow.py +++ b/tests/test_handler_mlflow.py @@ -122,6 +122,11 @@ def _train_func(engine, batch): def _update_metric(engine): current_metric = engine.state.metrics.get("acc", 0.1) engine.state.metrics["acc"] = current_metric + 0.1 + # log nested metrics + engine.state.metrics["acc_per_label"] = { + "label_0": current_metric + 0.1, + "label_1": current_metric + 0.2, + } engine.state.test = current_metric # set up testing handler @@ -138,10 +143,12 @@ def _update_metric(engine): state_attributes=["test"], experiment_param=experiment_param, artifacts=[artifact_path], - close_on_complete=True, + close_on_complete=False, ) handler.attach(engine) engine.run(range(3), max_epochs=2) + cur_run = handler.client.get_run(handler.cur_run.info.run_id) + self.assertTrue("label_0" in cur_run.data.metrics.keys()) handler.close() # check logging output self.assertTrue(len(glob.glob(test_path)) > 0)