diff --git a/elastica/modules/callbacks.py b/elastica/modules/callbacks.py index 9d886926..4028b94a 100644 --- a/elastica/modules/callbacks.py +++ b/elastica/modules/callbacks.py @@ -74,6 +74,9 @@ def _finalize_callback(self: SystemCollectionProtocol) -> None: self._callback_list.clear() del self._callback_list + # First callback execution + self.apply_callbacks(time=np.float64(0.0), current_step=0) + class _CallBack: """ diff --git a/tests/test_modules/test_callbacks.py b/tests/test_modules/test_callbacks.py index d5830229..09250c17 100644 --- a/tests/test_modules/test_callbacks.py +++ b/tests/test_modules/test_callbacks.py @@ -184,3 +184,18 @@ def test_callback_finalize_sorted(self, load_rod_with_callbacks): for x, _ in scwc._callback_list: assert num < x num = x + + def test_first_call_callback_during_finalize(self, mocker, load_rod_with_callbacks): + """ + This test is to check if the callback is called during the finalize. + If this test fails, check if `apply_callbacks` is called during the finalization step. + """ + scwc, callback_cls = load_rod_with_callbacks + callback_features = [d for d in scwc._callback_list] + + spy = mocker.spy(scwc, "apply_callbacks") + scwc._finalize_callback() + + assert spy.call_count == 1 + assert spy.call_args[1]["time"] == np.float64(0.0) + assert spy.call_args[1]["current_step"] == 0