Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use OperatorGroup for constrain and callback features #404

Open
wants to merge 10 commits into
base: update-0.3.3
Choose a base branch
from
Prev Previous commit
Next Next commit
use OperatorGroup to handle callback operator
  • Loading branch information
skim0119 committed Jun 29, 2024
commit d88df678ab1b3d94ab64e1afcf246ef4c39ddd20
4 changes: 3 additions & 1 deletion elastica/modules/base_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def __init__(self) -> None:
self._feature_group_constrain_rates: OperatorGroupFIFO[
OperatorType, ModuleProtocol
] = OperatorGroupFIFO()
self._feature_group_callback: list[OperatorCallbackType] = []
self._feature_group_callback: OperatorGroupFIFO[
OperatorCallbackType, ModuleProtocol
] = OperatorGroupFIFO()
self._feature_group_finalize: list[OperatorFinalizeType] = []
# We need to initialize our mixin classes
super().__init__()
Expand Down
35 changes: 15 additions & 20 deletions elastica/modules/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from elastica.typing import SystemType, SystemIdxType, OperatorFinalizeType
from .protocol import ModuleProtocol

import functools

import numpy as np

from elastica.callback_functions import CallBackBaseClass
Expand All @@ -29,9 +31,7 @@ class CallBacks:

def __init__(self: SystemCollectionProtocol) -> None:
self._callback_list: list[ModuleProtocol] = []
self._callback_operators: list[tuple[int, CallBackBaseClass]] = []
super(CallBacks, self).__init__()
self._feature_group_callback.append(self._callback_execution)
self._feature_group_finalize.append(self._finalize_callback)

def collect_diagnostics(
Expand All @@ -54,31 +54,26 @@ def collect_diagnostics(
sys_idx: SystemIdxType = self.get_system_index(system)

# Create _Constraint object, cache it and return to user
_callbacks: ModuleProtocol = _CallBack(sys_idx)
self._callback_list.append(_callbacks)
_callback: ModuleProtocol = _CallBack(sys_idx)
self._callback_list.append(_callback)
self._feature_group_callback.append_id(_callback)

return _callbacks
return _callback

def _finalize_callback(self: SystemCollectionProtocol) -> None:
# dev : the first index stores the rod index to collect data.
self._callback_operators = [
(callback.id(), callback.instantiate()) for callback in self._callback_list
]
for callback in self._callback_list:
sys_id = callback.id()
callback_instance = callback.instantiate()

callback_operator = functools.partial(
callback_instance.make_callback, system=self[sys_id]
)
self._feature_group_callback.add_operators(callback, [callback_operator])

self._callback_list.clear()
del self._callback_list

# First callback execution
time = np.float64(0.0)
self._callback_execution(time=time, current_step=0)

def _callback_execution(
self: SystemCollectionProtocol,
time: np.float64,
current_step: int,
) -> None:
for sys_id, callback in self._callback_operators:
callback.make_callback(self[sys_id], time, current_step)


class _CallBack:
"""
Expand Down
11 changes: 3 additions & 8 deletions elastica/modules/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def _feature_group_constrain_rates(
def constrain_rates(self, time: np.float64) -> None: ...

@property
def _feature_group_callback(self) -> list[OperatorCallbackType]: ...
def _feature_group_callback(
self,
) -> OperatorGroupFIFO[OperatorCallbackType, ModuleProtocol]: ...

def apply_callbacks(self, time: np.float64, current_step: int) -> None: ...

Expand Down Expand Up @@ -102,18 +104,11 @@ def connect(
# CallBack API
_finalize_callback: OperatorFinalizeType
_callback_list: list[ModuleProtocol]
_callback_operators: list[tuple[int, CallBackBaseClass]]

@abstractmethod
def collect_diagnostics(self, system: SystemType) -> ModuleProtocol:
raise NotImplementedError

@abstractmethod
def _callback_execution(
self, time: np.float64, current_step: int, *args: Any, **kwargs: Any
) -> None:
raise NotImplementedError

# Constraints API
_constraints_list: list[ModuleProtocol]
_finalize_constraints: OperatorFinalizeType
Expand Down