Skip to content

Commit 5657b8f

Browse files
authored
5762 pprint head and tail bundle script (#5969)
Signed-off-by: Wenqi Li <wenqil@nvidia.com> Fixes #5762 ### Description limiting the number of printing lines ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent 94feae5 commit 5657b8f

File tree

4 files changed

+32
-3
lines changed

4 files changed

+32
-3
lines changed

monai/bundle/scripts.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import ast
1515
import json
1616
import os
17-
import pprint
1817
import re
1918
import time
2019
import warnings
@@ -37,7 +36,7 @@
3736
from monai.data import load_net_with_metadata, save_net_with_metadata
3837
from monai.networks import convert_to_torchscript, copy_model_state, get_state_dict, save_state
3938
from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import
40-
from monai.utils.misc import ensure_tuple
39+
from monai.utils.misc import ensure_tuple, pprint_edges
4140

4241
validate, _ = optional_import("jsonschema", name="validate")
4342
ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
@@ -48,6 +47,7 @@
4847

4948
# set BUNDLE_DOWNLOAD_SRC="ngc" to use NGC source in default for bundle download
5049
download_source = os.environ.get("BUNDLE_DOWNLOAD_SRC", "github")
50+
PPRINT_CONFIG_N = 5
5151

5252

5353
def _update_args(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict:
@@ -88,7 +88,7 @@ def _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple:
8888
def _log_input_summary(tag: str, args: dict) -> None:
8989
logger.info(f"--- input summary of monai.bundle.scripts.{tag} ---")
9090
for name, val in args.items():
91-
logger.info(f"> {name}: {pprint.pformat(val)}")
91+
logger.info(f"> {name}: {pprint_edges(val, PPRINT_CONFIG_N)}")
9292
logger.info("---\n\n")
9393

9494

monai/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
issequenceiterable,
7878
list_to_dict,
7979
path_to_uri,
80+
pprint_edges,
8081
progress_bar,
8182
sample_slices,
8283
save_obj,

monai/utils/misc.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import inspect
1515
import itertools
1616
import os
17+
import pprint
1718
import random
1819
import shutil
1920
import tempfile
@@ -60,6 +61,7 @@
6061
"save_obj",
6162
"label_union",
6263
"path_to_uri",
64+
"pprint_edges",
6365
]
6466

6567
_seed = None
@@ -626,3 +628,17 @@ def path_to_uri(path: PathLike) -> str:
626628
627629
"""
628630
return Path(path).absolute().as_uri()
631+
632+
633+
def pprint_edges(val: Any, n_lines: int = 20) -> str:
634+
"""
635+
Pretty print the head and tail ``n_lines`` of ``val``, and omit the middle part if the part has more than 3 lines.
636+
637+
Returns: the formatted string.
638+
"""
639+
val_str = pprint.pformat(val).splitlines(True)
640+
n_lines = max(n_lines, 1)
641+
if len(val_str) > n_lines * 2 + 3:
642+
hidden_n = len(val_str) - n_lines * 2
643+
val_str = val_str[:n_lines] + [f"\n ... omitted {hidden_n} line(s)\n\n"] + val_str[-n_lines:]
644+
return "".join(val_str)

tests/test_bundle_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from monai.bundle.utils import load_bundle_config
2222
from monai.networks.nets import UNet
23+
from monai.utils import pprint_edges
2324
from tests.utils import command_line_tests, skip_if_windows
2425

2526
metadata = """
@@ -117,5 +118,16 @@ def test_load_config_ts(self):
117118
self.assertEqual(p["test_dict"]["b"], "c")
118119

119120

121+
class TestPPrintEdges(unittest.TestCase):
122+
def test_str(self):
123+
self.assertEqual(pprint_edges("", 0), "''")
124+
self.assertEqual(pprint_edges({"a": 1, "b": 2}, 0), "{'a': 1, 'b': 2}")
125+
self.assertEqual(
126+
pprint_edges([{"a": 1, "b": 2}] * 20, 1),
127+
"[{'a': 1, 'b': 2},\n\n ... omitted 18 line(s)\n\n {'a': 1, 'b': 2}]",
128+
)
129+
self.assertEqual(pprint_edges([{"a": 1, "b": 2}] * 8, 4), pprint_edges([{"a": 1, "b": 2}] * 8, 3))
130+
131+
120132
if __name__ == "__main__":
121133
unittest.main()

0 commit comments

Comments
 (0)