Fix the logging of a nested dictionary metric in MLflow (#8169)
Fix Project-MONAI/model-zoo#697

### Description
Flatten the metrics dict before logging when a metric value is a nested dictionary, so that each leaf value is logged as a scalar.
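
For reference, a minimal self-contained sketch of the flattening behaviour introduced here, mirroring the `flatten_dict` helper added to `monai/utils/misc.py` in this commit; the metric names (`acc`, `acc_per_label`, `label_0`, `label_1`) are illustrative values taken from the updated test:

```python
from typing import Any


def flatten_dict(metrics: dict[str, Any]) -> dict[str, Any]:
    """Recursively flatten nested metric dictionaries into a single flat dict."""
    result: dict[str, Any] = {}
    for key, value in metrics.items():
        if isinstance(value, dict):
            # nested dict: merge its (flattened) entries, dropping the parent key
            result.update(flatten_dict(value))
        else:
            result[key] = value
    return result


# a per-label metric stored as a nested dict, as in the updated test below
metrics = {"acc": 0.3, "acc_per_label": {"label_0": 0.3, "label_1": 0.4}}
print(flatten_dict(metrics))
# {'acc': 0.3, 'label_0': 0.3, 'label_1': 0.4}
```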

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
KumoLiu and pre-commit-ci[bot] authored Oct 29, 2024
1 parent 82298ad commit 3c252a8
Showing 6 changed files with 34 additions and 13 deletions.
6 changes: 4 additions & 2 deletions monai/handlers/mlflow_handler.py
@@ -22,7 +22,7 @@
 from torch.utils.data import Dataset
 
 from monai.apps.utils import get_logger
-from monai.utils import CommonKeys, IgniteInfo, ensure_tuple, min_version, optional_import
+from monai.utils import CommonKeys, IgniteInfo, ensure_tuple, flatten_dict, min_version, optional_import
 
 Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
 mlflow, _ = optional_import("mlflow", descriptor="Please install mlflow before using MLFlowHandler.")
@@ -303,7 +303,9 @@ def _log_metrics(self, metrics: dict[str, Any], step: int | None = None) -> None
 
         run_id = self.cur_run.info.run_id
         timestamp = int(time.time() * 1000)
-        metrics_arr = [mlflow.entities.Metric(key, value, timestamp, step or 0) for key, value in metrics.items()]
+        metrics_arr = [
+            mlflow.entities.Metric(key, value, timestamp, step or 0) for key, value in flatten_dict(metrics).items()
+        ]
         self.client.log_batch(run_id=run_id, metrics=metrics_arr, params=[], tags=[])
 
     def _parse_artifacts(self):
5 changes: 2 additions & 3 deletions monai/handlers/stats_handler.py
@@ -19,7 +19,7 @@
 import torch
 
 from monai.apps import get_logger
-from monai.utils import IgniteInfo, is_scalar, min_version, optional_import
+from monai.utils import IgniteInfo, flatten_dict, is_scalar, min_version, optional_import
 
 Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
 if TYPE_CHECKING:
@@ -211,8 +211,7 @@ def _default_epoch_print(self, engine: Engine) -> None:
 
         """
         current_epoch = self.global_epoch_transform(engine.state.epoch)
-
-        prints_dict = engine.state.metrics
+        prints_dict = flatten_dict(engine.state.metrics)
         if prints_dict is not None and len(prints_dict) > 0:
             out_str = f"Epoch[{current_epoch}] Metrics -- "
             for name in sorted(prints_dict):
13 changes: 6 additions & 7 deletions monai/networks/utils.py
@@ -16,6 +16,7 @@
 
 import io
 import re
+import tempfile
 import warnings
 from collections import OrderedDict
 from collections.abc import Callable, Mapping, Sequence
@@ -688,27 +689,25 @@ def convert_to_onnx(
         onnx_inputs = (inputs,)
     else:
         onnx_inputs = tuple(inputs)
-
+    temp_file = None
     if filename is None:
-        f = io.BytesIO()
+        temp_file = tempfile.NamedTemporaryFile()
+        f = temp_file.name
     else:
         f = filename
 
     torch.onnx.export(
         mode_to_export,
         onnx_inputs,
-        f=f,  # type: ignore[arg-type]
+        f=f,
         input_names=input_names,
         output_names=output_names,
         dynamic_axes=dynamic_axes,
         opset_version=opset_version,
         do_constant_folding=do_constant_folding,
         **torch_versioned_kwargs,
     )
-    if filename is None:
-        onnx_model = onnx.load_model_from_string(f.getvalue())
-    else:
-        onnx_model = onnx.load(filename)
+    onnx_model = onnx.load(f)
 
     if do_constant_folding and polygraphy_imported:
         from polygraphy.backend.onnx.loader import fold_constants
1 change: 1 addition & 0 deletions monai/utils/__init__.py
@@ -78,6 +78,7 @@
     ensure_tuple_size,
     fall_back_tuple,
     first,
+    flatten_dict,
     get_seed,
     has_option,
     is_immutable,
13 changes: 13 additions & 0 deletions monai/utils/misc.py
@@ -916,3 +916,16 @@ def unsqueeze_right(arr: NT, ndim: int) -> NT:
 def unsqueeze_left(arr: NT, ndim: int) -> NT:
     """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
     return arr[(None,) * (ndim - arr.ndim)]
+
+
+def flatten_dict(metrics: dict[str, Any]) -> dict[str, Any]:
+    """
+    Flatten the nested dictionary to a flat dictionary.
+    """
+    result = {}
+    for key, value in metrics.items():
+        if isinstance(value, dict):
+            result.update(flatten_dict(value))
+        else:
+            result[key] = value
+    return result
9 changes: 8 additions & 1 deletion tests/test_handler_mlflow.py
@@ -122,6 +122,11 @@ def _train_func(engine, batch):
             def _update_metric(engine):
                 current_metric = engine.state.metrics.get("acc", 0.1)
                 engine.state.metrics["acc"] = current_metric + 0.1
+                # log nested metrics
+                engine.state.metrics["acc_per_label"] = {
+                    "label_0": current_metric + 0.1,
+                    "label_1": current_metric + 0.2,
+                }
                 engine.state.test = current_metric
 
             # set up testing handler
@@ -138,10 +143,12 @@ def _update_metric(engine):
                 state_attributes=["test"],
                 experiment_param=experiment_param,
                 artifacts=[artifact_path],
-                close_on_complete=True,
+                close_on_complete=False,
             )
             handler.attach(engine)
             engine.run(range(3), max_epochs=2)
+            cur_run = handler.client.get_run(handler.cur_run.info.run_id)
+            self.assertTrue("label_0" in cur_run.data.metrics.keys())
             handler.close()
             # check logging output
             self.assertTrue(len(glob.glob(test_path)) > 0)
