From 8810c4faf1491142f360d0a510c6e9f2607e8454 Mon Sep 17 00:00:00 2001 From: Seung Hyun Kim Date: Sun, 30 Jun 2024 21:04:52 -0500 Subject: [PATCH] fix: first callback execution --- elastica/modules/callbacks.py | 3 +++ tests/test_modules/test_callbacks.py | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/elastica/modules/callbacks.py b/elastica/modules/callbacks.py index 9d886926..4028b94a 100644 --- a/elastica/modules/callbacks.py +++ b/elastica/modules/callbacks.py @@ -74,6 +74,9 @@ def _finalize_callback(self: SystemCollectionProtocol) -> None: self._callback_list.clear() del self._callback_list + # First callback execution + self.apply_callbacks(time=np.float64(0.0), current_step=0) + class _CallBack: """ diff --git a/tests/test_modules/test_callbacks.py b/tests/test_modules/test_callbacks.py index d5830229..da786b75 100644 --- a/tests/test_modules/test_callbacks.py +++ b/tests/test_modules/test_callbacks.py @@ -184,3 +184,15 @@ def test_callback_finalize_sorted(self, load_rod_with_callbacks): for x, _ in scwc._callback_list: assert num < x num = x + + def test_first_call_callback_during_finalize(self, mocker, load_rod_with_callbacks): + scwc, callback_cls = load_rod_with_callbacks + callback_features = [d for d in scwc._callback_list] + + spy = mocker.spy(scwc, "apply_callbacks") + scwc._finalize_callback() + + assert spy.call_count == 1 + breakpoint() + assert spy.call_args[1]["time"] == np.float64(0.0) + assert spy.call_args[1]["current_step"] == 0