Skip to content

Commit de2b48c

Browse files
vfdev-5wyliericspodpre-commit-ci[bot]
authored
Added callable options for iteration_log and epoch_log in StatsHandler (#5965)
Fixes #5964 ### Description Added callable options for iteration_log and epoch_log in StatsHandler. Ref: #5958 (reply in thread) ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: vfdev-5 <vfdev.5@gmail.com> Signed-off-by: Wenqi Li <wenqil@nvidia.com> Signed-off-by: vfdev <vfdev.5@gmail.com> Co-authored-by: Wenqi Li <wenqil@nvidia.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Wenqi Li <831580+wyli@users.noreply.github.com>
1 parent 5657b8f commit de2b48c

File tree

2 files changed

+49
-14
lines changed

2 files changed

+49
-14
lines changed

monai/handlers/stats_handler.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ class StatsHandler:
6666

6767
def __init__(
6868
self,
69-
iteration_log: bool = True,
70-
epoch_log: bool = True,
69+
iteration_log: bool | Callable[[Engine, int], bool] = True,
70+
epoch_log: bool | Callable[[Engine, int], bool] = True,
7171
epoch_print_logger: Callable[[Engine], Any] | None = None,
7272
iteration_print_logger: Callable[[Engine], Any] | None = None,
7373
output_transform: Callable = lambda x: x[0],
@@ -80,8 +80,14 @@ def __init__(
8080
"""
8181
8282
Args:
83-
iteration_log: whether to log data when iteration completed, default to `True`.
84-
epoch_log: whether to log data when epoch completed, default to `True`.
83+
iteration_log: whether to log data when iteration completed, default to `True`. ``iteration_log`` can
84+
be also a function and it will be interpreted as an event filter
85+
(see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details).
86+
Event filter function accepts as input engine and event value (iteration) and should return True/False.
87+
Event filtering can be helpful to customize iteration logging frequency.
88+
epoch_log: whether to log data when epoch completed, default to `True`. ``epoch_log`` can be
89+
also a function and it will be interpreted as an event filter. See ``iteration_log`` argument for more
90+
details.
8591
epoch_print_logger: customized callable printer for epoch level logging.
8692
Must accept parameter "engine", use default printer if None.
8793
iteration_print_logger: customized callable printer for iteration level logging.
@@ -135,9 +141,15 @@ def attach(self, engine: Engine) -> None:
135141
" please call `logging.basicConfig(stream=sys.stdout, level=logging.INFO)` to enable it."
136142
)
137143
if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
138-
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
144+
event = Events.ITERATION_COMPLETED
145+
if callable(self.iteration_log): # substitute event with new one using filter callable
146+
event = event(event_filter=self.iteration_log)
147+
engine.add_event_handler(event, self.iteration_completed)
139148
if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
140-
engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed)
149+
event = Events.EPOCH_COMPLETED
150+
if callable(self.epoch_log): # substitute event with new one using filter callable
151+
event = event(event_filter=self.epoch_log)
152+
engine.add_event_handler(event, self.epoch_completed)
141153
if not engine.has_event_handler(self.exception_raised, Events.EXCEPTION_RAISED):
142154
engine.add_event_handler(Events.EXCEPTION_RAISED, self.exception_raised)
143155

tests/test_handler_stats.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,23 @@
2020

2121
import torch
2222
from ignite.engine import Engine, Events
23+
from parameterized import parameterized
2324

2425
from monai.handlers import StatsHandler
2526

2627

28+
def get_event_filter(e):
29+
def event_filter(_, event):
30+
if event in e:
31+
return True
32+
return False
33+
34+
return event_filter
35+
36+
2737
class TestHandlerStats(unittest.TestCase):
28-
def test_metrics_print(self):
38+
@parameterized.expand([[True], [get_event_filter([1, 2])]])
39+
def test_metrics_print(self, epoch_log):
2940
log_stream = StringIO()
3041
log_handler = logging.StreamHandler(log_stream)
3142
log_handler.setLevel(logging.INFO)
@@ -48,10 +59,11 @@ def _update_metric(engine):
4859
logger = logging.getLogger(key_to_handler)
4960
logger.setLevel(logging.INFO)
5061
logger.addHandler(log_handler)
51-
stats_handler = StatsHandler(iteration_log=False, epoch_log=True, name=key_to_handler)
62+
stats_handler = StatsHandler(iteration_log=False, epoch_log=epoch_log, name=key_to_handler)
5263
stats_handler.attach(engine)
5364

54-
engine.run(range(3), max_epochs=2)
65+
max_epochs = 4
66+
engine.run(range(3), max_epochs=max_epochs)
5567

5668
# check logging output
5769
output_str = log_stream.getvalue()
@@ -61,9 +73,13 @@ def _update_metric(engine):
6173
for line in output_str.split("\n"):
6274
if has_key_word.match(line):
6375
content_count += 1
64-
self.assertTrue(content_count > 0)
76+
if epoch_log is True:
77+
self.assertTrue(content_count == max_epochs)
78+
else:
79+
self.assertTrue(content_count == 2) # 2 = len([1, 2]) from event_filter
6580

66-
def test_loss_print(self):
81+
@parameterized.expand([[True], [get_event_filter([1, 3])]])
82+
def test_loss_print(self, iteration_log):
6783
log_stream = StringIO()
6884
log_handler = logging.StreamHandler(log_stream)
6985
log_handler.setLevel(logging.INFO)
@@ -80,10 +96,14 @@ def _train_func(engine, batch):
8096
logger = logging.getLogger(key_to_handler)
8197
logger.setLevel(logging.INFO)
8298
logger.addHandler(log_handler)
83-
stats_handler = StatsHandler(iteration_log=True, epoch_log=False, name=key_to_handler, tag_name=key_to_print)
99+
stats_handler = StatsHandler(
100+
iteration_log=iteration_log, epoch_log=False, name=key_to_handler, tag_name=key_to_print
101+
)
84102
stats_handler.attach(engine)
85103

86-
engine.run(range(3), max_epochs=2)
104+
num_iters = 3
105+
max_epochs = 2
106+
engine.run(range(num_iters), max_epochs=max_epochs)
87107

88108
# check logging output
89109
output_str = log_stream.getvalue()
@@ -93,7 +113,10 @@ def _train_func(engine, batch):
93113
for line in output_str.split("\n"):
94114
if has_key_word.match(line):
95115
content_count += 1
96-
self.assertTrue(content_count > 0)
116+
if iteration_log is True:
117+
self.assertTrue(content_count == num_iters * max_epochs)
118+
else:
119+
self.assertTrue(content_count == 2) # 2 = len([1, 3]) from event_filter
97120

98121
def test_loss_dict(self):
99122
log_stream = StringIO()

0 commit comments

Comments
 (0)