Skip to content

Commit

Permalink
Issue #0: Minor changes to data system
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark2000 committed Sep 23, 2024
1 parent 7198344 commit 3e8eaff
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
9 changes: 7 additions & 2 deletions src/bsk_rl/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,15 @@ def initial_data(self, satellite: "Satellite") -> "Data":
"""Furnish the :class:`~bsk_rl.data.base.DataStore` with initial data."""
return self.data_type()

def create_data_store(self, satellite: "Satellite") -> None:
def create_data_store(self, satellite: "Satellite", **data_store_kwargs) -> None:
"""Create a data store for a satellite.
Args:
satellite: Satellite to create a data store for.
data_store_kwargs: Additional keyword arguments to pass to the data store
"""
satellite.data_store = self.datastore_type(
satellite, initial_data=self.initial_data(satellite)
satellite, initial_data=self.initial_data(satellite), **data_store_kwargs
)
self.cum_reward[satellite.name] = 0.0

Expand Down Expand Up @@ -183,6 +184,10 @@ def reward(self, new_data_dict: dict[str, Data]) -> dict[str, float]:
reward = self.calculate_reward(new_data_dict)
for satellite_id, sat_reward in reward.items():
self.cum_reward[satellite_id] += sat_reward

for new_data in new_data_dict.values():
self.data += new_data

logger.info(f"Data reward: {reward}")
return reward

Expand Down
2 changes: 0 additions & 2 deletions src/bsk_rl/data/unique_image_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,6 @@ def calculate_reward(
target.priority
) / imaged_targets.count(target)

for new_data in new_data_dict.values():
self.data += new_data
return reward


Expand Down
1 change: 1 addition & 0 deletions tests/unittest/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def test_create_data_store(self):

def test_reward(self):
dm = GlobalReward()
dm.reset_overwrite_previous()
dm.calculate_reward = MagicMock(return_value={"sat": 10.0})
dm.cum_reward = {"sat": 5.0}
assert {"sat": 10.0} == dm.reward({"sat": "data"})
Expand Down

0 comments on commit 3e8eaff

Please sign in to comment.