diff --git a/clu/periodic_actions.py b/clu/periodic_actions.py index 086734b..673122c 100644 --- a/clu/periodic_actions.py +++ b/clu/periodic_actions.py @@ -18,6 +18,7 @@ import collections import concurrent.futures import contextlib +import functools import os import time from typing import Callable, Iterable, Optional, Sequence @@ -410,15 +411,18 @@ def _apply(self, step: int, t: float): def _start_session(self): profiler.collect( logdir=self._logdir, - callback=self._end_session, + # Callback is executed asynchronously, so bind `self._previous_step` + callback=functools.partial(self._end_session, step=self._previous_step), hosts=self._hosts, - duration_ms=self._profile_duration_ms) + duration_ms=self._profile_duration_ms, + ) - def _end_session(self, url: Optional[str]): + def _end_session(self, url: Optional[str], *, step: int): platform.work_unit().create_artifact( platform.ArtifactType.URL, url, - description=f"[{self._previous_step}] Profile") + description=f"[{step}] Profile", + ) class PeriodicCallback(PeriodicAction):