From 36d93082970365b7c5ae7022aad9f9a7308ceb82 Mon Sep 17 00:00:00 2001 From: Alec Koumjian Date: Fri, 24 Jan 2025 11:09:37 -0500 Subject: [PATCH] Use __getstate__ and __setstate__ for propagator serialization (#139) * Use __getstate__ and __setstate__ for propagator serialization * Give helpful error messages --- pyproject.toml | 2 +- src/adam_core/dynamics/tests/test_impacts.py | 7 ++ src/adam_core/propagator/propagator.py | 66 ++++++++++++++----- .../propagator/tests/test_propagator.py | 8 +++ 4 files changed, 66 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c93ce519..e546bd64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ authors = [ ] description = "Core libraries for the ADAM platform" readme = "README.md" -requires-python = ">=3.10,<3.13" +requires-python = ">=3.11,<3.13" classifiers = [ "Operating System :: OS Independent", "Development Status :: 4 - Beta", diff --git a/src/adam_core/dynamics/tests/test_impacts.py b/src/adam_core/dynamics/tests/test_impacts.py index a32b847a..d0fe8cdd 100644 --- a/src/adam_core/dynamics/tests/test_impacts.py +++ b/src/adam_core/dynamics/tests/test_impacts.py @@ -16,6 +16,13 @@ class MockImpactPropagator(Propagator, ImpactMixin): + def __getstate__(self): + state = self.__dict__.copy() + return state + + def __setstate__(self, state): + self.__dict__.update(state) + def _propagate_orbits(self, orbits: Orbits, times: Timestamp) -> Orbits: return orbits diff --git a/src/adam_core/propagator/propagator.py b/src/adam_core/propagator/propagator.py index 10efce04..76362db7 100644 --- a/src/adam_core/propagator/propagator.py +++ b/src/adam_core/propagator/propagator.py @@ -69,12 +69,10 @@ def propagation_worker_ray( idx: npt.NDArray[np.int64], orbits: OrbitType, times: OrbitType, - propagator: Type["Propagator"], - **kwargs, + propagator: "Propagator", ) -> OrbitType: - prop = propagator(**kwargs) orbits_chunk = orbits.take(idx) - propagated = prop._propagate_orbits(orbits_chunk, times) + propagated = propagator._propagate_orbits(orbits_chunk, times) return propagated @ray.remote @@ -82,12 +80,10 @@ def ephemeris_worker_ray( idx: npt.NDArray[np.int64], orbits: OrbitType, observers: ObserverType, - propagator: Type["Propagator"], - **kwargs, + propagator: "Propagator", ) -> EphemerisType: - prop = propagator(**kwargs) orbits_chunk = orbits.take(idx) - ephemeris = prop._generate_ephemeris(orbits_chunk, observers) + ephemeris = propagator._generate_ephemeris(orbits_chunk, observers) return ephemeris @@ -369,8 +365,7 @@ def generate_ephemeris( idx_chunk, orbits_ref, observers_ref, - self.__class__, - **self.__dict__, + self, ) ) @@ -393,8 +388,7 @@ def generate_ephemeris( variant_chunk_idx, variants_ref, observers_ref, - self.__class__, - **self.__dict__, + self, ) ) @@ -466,6 +460,48 @@ def _propagate_orbits(self, orbits: OrbitType, times: TimestampType) -> OrbitTyp """ pass + def __getstate__(self): + """ + Get the state of the propagator. + + Subclasses need to define what is picklable for multiprocessing. + + e.g. + + def __getstate__(self): + state = self.__dict__.copy() + state.pop("_stateful_attribute_that_is_not_pickleable") + return state + """ + raise NotImplementedError( + "Propagator must implement __getstate__ for multiprocessing serialization.\n" + "Example implementation: \n" + "def __getstate__(self):\n" + " state = self.__dict__.copy()\n" + " state.pop('_stateful_attribute_that_is_not_pickleable')\n" + " return state" + ) + + def __setstate__(self, state): + """ + Set the state of the propagator. + + Subclasses need to define what is unpicklable for multiprocessing. + + e.g. + + def __setstate__(self, state): + self.__dict__.update(state) + self._stateful_attribute_that_is_not_pickleable = None + """ + raise NotImplementedError( + "Propagator must implement __setstate__ for multiprocessing serialization.\n" + "Example implementation: \n" + "def __setstate__(self, state):\n" + " self.__dict__.update(state)\n" + " self._stateful_attribute_that_is_not_pickleable = None" + ) + def propagate_orbits( self, orbits: Union[OrbitType, ObjectRef], @@ -551,8 +587,7 @@ def propagate_orbits( idx_chunk, orbits_ref, times_ref, - self.__class__, - **self.__dict__, + self, ) ) @@ -574,8 +609,7 @@ def propagate_orbits( variant_chunk_idx, variants_ref, times_ref, - self.__class__, - **self.__dict__, + self, ) ) diff --git a/src/adam_core/propagator/tests/test_propagator.py b/src/adam_core/propagator/tests/test_propagator.py index b4ee6f90..9271e212 100644 --- a/src/adam_core/propagator/tests/test_propagator.py +++ b/src/adam_core/propagator/tests/test_propagator.py @@ -17,6 +17,14 @@ class MockPropagator(Propagator, EphemerisMixin): + + def __getstate__(self): + state = self.__dict__.copy() + return state + + def __setstate__(self, state): + self.__dict__.update(state) + # MockPropagator propagates orbits by just setting the time of the orbits. def _propagate_orbits(self, orbits: Orbits, times: Timestamp) -> Orbits: all_times = []