Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 45 additions & 11 deletions keras/src/callbacks/callback_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
via `Callback.set_params`.
"""
self.callbacks = tree.flatten(callbacks) if callbacks else []
self._in_begin_end_block_count = 0
self._executor = None
self._async_train = False
self._async_test = False
Expand Down Expand Up @@ -78,9 +79,6 @@ def _configure_async_dispatch(self, callbacks):
if not utils.is_default(cbk.on_predict_batch_end):
async_predict = False

if async_train or async_test or async_predict:
self._executor = concurrent.futures.ThreadPoolExecutor()

self._async_train = async_train
self._async_test = async_test
self._async_predict = async_predict
Expand Down Expand Up @@ -113,6 +111,33 @@ def set_model(self, model):
for callback in self.callbacks:
callback.set_model(model)

def _on_begin(self):
"""Called by `on_train/test/predict_begin`.

Start the executor for async calls if needed.
"""
self._in_begin_end_block_count += 1
if (
self._in_begin_end_block_count == 1
and (self._async_train or self._async_test or self._async_predict)
and self._executor is None
):
self._executor = concurrent.futures.ThreadPoolExecutor()

def _on_end(self):
"""Called by `on_train/test/predict_end`.

Shutdown the executor for async calls if all begin/end blocks completed.
"""
self._in_begin_end_block_count -= 1
if self._in_begin_end_block_count < 0:
raise ValueError(
"`on_xxx_end` called without corresponding `on_xxx_begin`"
)
if self._in_begin_end_block_count == 0 and self._executor is not None:
self._executor.shutdown()
self._executor = None

def _async_dispatch(self, fn, *args):
for future in self._futures:
if future.done():
Expand All @@ -121,7 +146,8 @@ def _async_dispatch(self, fn, *args):
future = self._executor.submit(fn, *args)
self._futures.append(future)

def _clear_futures(self):
def _flush_futures(self):
"""Waits for all futures to complete and clears the list."""
for future in self._futures:
future.result()
self._futures = []
Expand All @@ -138,7 +164,7 @@ def on_epoch_begin(self, epoch, logs=None):

def on_epoch_end(self, epoch, logs=None):
if self._async_train:
self._clear_futures()
self._flush_futures()

logs = python_utils.pythonify_logs(logs)
for callback in self.callbacks:
Expand Down Expand Up @@ -204,44 +230,52 @@ def _on_predict_batch_end(self, batch, logs=None):
callback.on_predict_batch_end(batch, logs=logs)

def on_train_begin(self, logs=None):
self._on_begin()

logs = python_utils.pythonify_logs(logs)
for callback in self.callbacks:
callback.on_train_begin(logs)

def on_train_end(self, logs=None):
if self._async_train:
self._clear_futures()
self._flush_futures()

logs = python_utils.pythonify_logs(logs)
for callback in self.callbacks:
callback.on_train_end(logs)

self._on_end()

def on_test_begin(self, logs=None):
self._on_begin()

logs = python_utils.pythonify_logs(logs)
for callback in self.callbacks:
callback.on_test_begin(logs)

def on_test_end(self, logs=None):
if self._async_test:
self._clear_futures()
self._flush_futures()

logs = python_utils.pythonify_logs(logs)
for callback in self.callbacks:
callback.on_test_end(logs)

self._on_end()

def on_predict_begin(self, logs=None):
self._on_begin()

logs = python_utils.pythonify_logs(logs)
for callback in self.callbacks:
callback.on_predict_begin(logs)

def on_predict_end(self, logs=None):
if self._async_predict:
self._clear_futures()
self._flush_futures()

logs = python_utils.pythonify_logs(logs)
for callback in self.callbacks:
callback.on_predict_end(logs)

def __del__(self):
if self._executor is not None:
self._executor.shutdown(cancel_futures=True)
self._on_end()