Skip to content

Commit

Permalink
Issue #66: Rewarder composition
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark2000 committed Dec 27, 2024
1 parent 45b941e commit cdf224d
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/bsk_rl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,15 @@
"""

from bsk_rl.data.base import GlobalReward
from bsk_rl.data.composition import ComposedReward
from bsk_rl.data.nadir_data import ScanningTimeReward
from bsk_rl.data.no_data import NoReward
from bsk_rl.data.unique_image_data import UniqueImageReward

__doc_title__ = "Data & Reward"
__all__ = [
"GlobalReward",
"ComposedReward",
"NoReward",
"UniqueImageReward",
"ScanningTimeReward",
Expand Down
2 changes: 1 addition & 1 deletion src/bsk_rl/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def reward(self, new_data_dict: dict[str, Data]) -> dict[str, float]:
self.data += new_data

nonzero_reward = {k: v for k, v in reward.items() if v != 0}
logger.info(f"Data reward: {nonzero_reward}")
logger.info(f"Total reward: {nonzero_reward}")
return reward


Expand Down
194 changes: 194 additions & 0 deletions src/bsk_rl/data/composition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import logging
from typing import TYPE_CHECKING, Callable, Optional

import numpy as np

from bsk_rl.data.base import Data, DataStore, GlobalReward
from bsk_rl.sats import Satellite
from bsk_rl.scene.scenario import Scenario

if TYPE_CHECKING:
from bsk_rl.sats import Satellite
from bsk_rl.scene.targets import Target

logger = logging.getLogger(__name__)


class ComposedData(Data):
"""Data for composed data types."""

def __init__(self, *data: Data) -> None:
"""Construct composed data.
Args:
data: Dictionary of data types to compose.
"""
self.data = data

def __add__(self, other: "ComposedData") -> "ComposedData":
"""Combine two units of composed data.
Args:
other: Another unit of composed data to combine with this one.
Returns:
Combined unit of composed data.
"""
if len(self.data) == 0 and len(other.data) == 0:
data = []
elif len(self.data) == 0:
data = [type(d)() + d for d in other.data]
elif len(other.data) == 0:
data = [d + type(d)() for d in self.data]
elif len(self.data) == len(other.data):
data = [d1 + d2 for d1, d2 in zip(self.data, other.data)]
else:
raise ValueError(
"ComposedData units must have the same number of data types."
)
return ComposedData(*data)

def __getattr__(self, name: str):
for data in self.data:
if hasattr(data, name):
return getattr(data, name)
raise AttributeError(f"No Data in ComposedData has attribute '{name}'")


class ComposedDataStore(DataStore):
data_type = ComposedData

def pass_data(self) -> Data:
for ds, data in zip(self.datastores, self.data.data):
ds.data = data

def __init__(
self,
satellite: "Satellite",
*datastore_types: type[DataStore],
initial_data: ComposedData = None,
):
super().__init__(satellite, initial_data)
self.datastores = [ds(satellite) for ds in datastore_types]
self.pass_data()

def __getattr__(self, name: str):
for datastore in self.datastores:
if hasattr(datastore, name):
return getattr(datastore, name)
raise AttributeError(
f"No DataStore in ComposedDataStore has attribute '{name}'"
)

def get_log_state(self) -> list:
log_states = [ds.get_log_state() for ds in self.datastores]
return log_states

def compare_log_states(self, prev_state: list, new_state: list) -> Data:
data = [
ds.compare_log_states(prev, new)
for ds, prev, new in zip(self.datastores, prev_state, new_state)
]
return ComposedData(*data)

def update_from_logs(self) -> Data:
new_data = super().update_from_logs()
self.pass_data()
return new_data

def update_with_communicated_data(self) -> None:
super().update_with_communicated_data()
self.pass_data()


class ComposedReward(GlobalReward):
datastore_type = ComposedDataStore

def pass_data(self) -> Data:
for rewarder, data in zip(self.rewarders, self.data.data):
rewarder.data = data

def __init__(self, *rewarders: GlobalReward) -> None:
"""Construct composed reward.
Args:
rewards: Global rewards to compose.
"""
super().__init__()
self.rewarders = rewarders

def __getattr__(self, name: str):
for rewarder in self.rewarders:
if hasattr(rewarder, name):
return getattr(rewarder, name)
raise AttributeError(
f"No GlobalReward in ComposedReward has attribute '{name}'"
)

def reset_pre_sim_init(self) -> None:
super().reset_pre_sim_init()
for rewarder in self.rewarders:
rewarder.reset_pre_sim_init()

def reset_post_sim_init(self) -> None:
super().reset_post_sim_init()
for rewarder in self.rewarders:
rewarder.reset_post_sim_init()

def reset_during_sim_init(self) -> None:
super().reset_during_sim_init()
for rewarder in self.rewarders:
rewarder.reset_during_sim_init()

def reset_overwrite_previous(self) -> None:
super().reset_overwrite_previous()
for rewarder in self.rewarders:
rewarder.reset_overwrite_previous()

def link_scenario(self, scenario: Scenario) -> None:
super().link_scenario(scenario)
for rewarder in self.rewarders:
rewarder.link_scenario(scenario)

def initial_data(self, satellite: Satellite) -> ComposedData:
return ComposedData(
*[rewarder.initial_data(satellite) for rewarder in self.rewarders]
)

def create_data_store(self, satellite: Satellite) -> None:
# TODO support passing kwargs
satellite.data_store = ComposedDataStore(
satellite,
*[r.datastore_type for r in self.rewarders],
initial_data=self.initial_data(satellite),
)
self.cum_reward[satellite.name] = 0.0

def calculate_reward(
self, new_data_dict: dict[str, ComposedData]
) -> dict[str, float]:
data_len = len(list(new_data_dict.values())[0].data)

for data in new_data_dict.values():
assert len(data.data) == data_len

reward = {}
if data_len != 0:
for i, rewarder in enumerate(self.rewarders):
reward_i = rewarder.calculate_reward(
{sat_id: data.data[i] for sat_id, data in new_data_dict.items()}
)

# Logging
nonzero_reward = {k: v for k, v in reward_i.items() if v != 0}
if len(nonzero_reward) > 0:
logger.info(f"{type(rewarder).__name__} reward: {nonzero_reward}")

for sat_id, sat_reward in reward_i.items():
reward[sat_id] = reward.get(sat_id, 0.0) + sat_reward
return reward

def reward(self, new_data_dict: dict[str, ComposedData]) -> dict[str, float]:
reward = super().reward(new_data_dict)
self.pass_data()
return reward
6 changes: 4 additions & 2 deletions src/bsk_rl/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pettingzoo.utils.env import AgentID, ParallelEnv

from bsk_rl.comm import CommunicationMethod, NoCommunication
from bsk_rl.data import GlobalReward, NoReward
from bsk_rl.data import ComposedReward, GlobalReward, NoReward
from bsk_rl.sats import Satellite
from bsk_rl.scene import Scenario
from bsk_rl.sim import Simulator
Expand All @@ -36,7 +36,7 @@ def __init__(
self,
satellites: Union[Satellite, list[Satellite]],
scenario: Optional[Scenario] = None,
rewarder: Optional[GlobalReward] = None,
rewarder: Optional[Union[GlobalReward, list[GlobalReward]]] = None,
world_type: Optional[type[WorldModel]] = None,
world_args: Optional[dict[str, Any]] = None,
communicator: Optional[CommunicationMethod] = None,
Expand Down Expand Up @@ -127,6 +127,8 @@ def __init__(
scenario = Scenario()
if rewarder is None:
rewarder = NoReward()
if isinstance(rewarder, Iterable):
rewarder = ComposedReward(*rewarder)

if world_type is None:
world_type = self._minimum_world_model()
Expand Down

0 comments on commit cdf224d

Please sign in to comment.