From 3e8eaffb0cde96adfd22b6dbf789c1aa87fe485d Mon Sep 17 00:00:00 2001 From: Mark Stephenson Date: Wed, 18 Sep 2024 16:26:06 -0600 Subject: [PATCH] Issue #0: Minor changes to data system --- src/bsk_rl/data/base.py | 9 +++++++-- src/bsk_rl/data/unique_image_data.py | 2 -- tests/unittest/data/test_data.py | 1 + 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/bsk_rl/data/base.py b/src/bsk_rl/data/base.py index b00f399b..b88b1c15 100644 --- a/src/bsk_rl/data/base.py +++ b/src/bsk_rl/data/base.py @@ -145,14 +145,15 @@ def initial_data(self, satellite: "Satellite") -> "Data": """Furnish the :class:`~bsk_rl.data.base.DataStore` with initial data.""" return self.data_type() - def create_data_store(self, satellite: "Satellite") -> None: + def create_data_store(self, satellite: "Satellite", **data_store_kwargs) -> None: """Create a data store for a satellite. Args: satellite: Satellite to create a data store for. + data_store_kwargs: Additional keyword arguments to pass to the data store """ satellite.data_store = self.datastore_type( - satellite, initial_data=self.initial_data(satellite) + satellite, initial_data=self.initial_data(satellite), **data_store_kwargs ) self.cum_reward[satellite.name] = 0.0 @@ -183,6 +184,10 @@ def reward(self, new_data_dict: dict[str, Data]) -> dict[str, float]: reward = self.calculate_reward(new_data_dict) for satellite_id, sat_reward in reward.items(): self.cum_reward[satellite_id] += sat_reward + + for new_data in new_data_dict.values(): + self.data += new_data + logger.info(f"Data reward: {reward}") return reward diff --git a/src/bsk_rl/data/unique_image_data.py b/src/bsk_rl/data/unique_image_data.py index 194aef81..ead45eb7 100644 --- a/src/bsk_rl/data/unique_image_data.py +++ b/src/bsk_rl/data/unique_image_data.py @@ -183,8 +183,6 @@ def calculate_reward( target.priority ) / imaged_targets.count(target) - for new_data in new_data_dict.values(): - self.data += new_data return reward diff --git a/tests/unittest/data/test_data.py b/tests/unittest/data/test_data.py index 6c34c84e..c7069b75 100644 --- a/tests/unittest/data/test_data.py +++ b/tests/unittest/data/test_data.py @@ -62,6 +62,7 @@ def test_create_data_store(self): def test_reward(self): dm = GlobalReward() + dm.reset_overwrite_previous() dm.calculate_reward = MagicMock(return_value={"sat": 10.0}) dm.cum_reward = {"sat": 5.0} assert {"sat": 10.0} == dm.reward({"sat": "data"})