-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
201 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
import logging | ||
from typing import TYPE_CHECKING, Callable, Optional | ||
|
||
import numpy as np | ||
|
||
from bsk_rl.data.base import Data, DataStore, GlobalReward | ||
from bsk_rl.sats import Satellite | ||
from bsk_rl.scene.scenario import Scenario | ||
|
||
if TYPE_CHECKING: | ||
from bsk_rl.sats import Satellite | ||
from bsk_rl.scene.targets import Target | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ComposedData(Data): | ||
"""Data for composed data types.""" | ||
|
||
def __init__(self, *data: Data) -> None: | ||
"""Construct composed data. | ||
Args: | ||
data: Dictionary of data types to compose. | ||
""" | ||
self.data = data | ||
|
||
def __add__(self, other: "ComposedData") -> "ComposedData": | ||
"""Combine two units of composed data. | ||
Args: | ||
other: Another unit of composed data to combine with this one. | ||
Returns: | ||
Combined unit of composed data. | ||
""" | ||
if len(self.data) == 0 and len(other.data) == 0: | ||
data = [] | ||
elif len(self.data) == 0: | ||
data = [type(d)() + d for d in other.data] | ||
elif len(other.data) == 0: | ||
data = [d + type(d)() for d in self.data] | ||
elif len(self.data) == len(other.data): | ||
data = [d1 + d2 for d1, d2 in zip(self.data, other.data)] | ||
else: | ||
raise ValueError( | ||
"ComposedData units must have the same number of data types." | ||
) | ||
return ComposedData(*data) | ||
|
||
def __getattr__(self, name: str): | ||
for data in self.data: | ||
if hasattr(data, name): | ||
return getattr(data, name) | ||
raise AttributeError(f"No Data in ComposedData has attribute '{name}'") | ||
|
||
|
||
class ComposedDataStore(DataStore): | ||
data_type = ComposedData | ||
|
||
def pass_data(self) -> Data: | ||
for ds, data in zip(self.datastores, self.data.data): | ||
ds.data = data | ||
|
||
def __init__( | ||
self, | ||
satellite: "Satellite", | ||
*datastore_types: type[DataStore], | ||
initial_data: ComposedData = None, | ||
): | ||
super().__init__(satellite, initial_data) | ||
self.datastores = [ds(satellite) for ds in datastore_types] | ||
self.pass_data() | ||
|
||
def __getattr__(self, name: str): | ||
for datastore in self.datastores: | ||
if hasattr(datastore, name): | ||
return getattr(datastore, name) | ||
raise AttributeError( | ||
f"No DataStore in ComposedDataStore has attribute '{name}'" | ||
) | ||
|
||
def get_log_state(self) -> list: | ||
log_states = [ds.get_log_state() for ds in self.datastores] | ||
return log_states | ||
|
||
def compare_log_states(self, prev_state: list, new_state: list) -> Data: | ||
data = [ | ||
ds.compare_log_states(prev, new) | ||
for ds, prev, new in zip(self.datastores, prev_state, new_state) | ||
] | ||
return ComposedData(*data) | ||
|
||
def update_from_logs(self) -> Data: | ||
new_data = super().update_from_logs() | ||
self.pass_data() | ||
return new_data | ||
|
||
def update_with_communicated_data(self) -> None: | ||
super().update_with_communicated_data() | ||
self.pass_data() | ||
|
||
|
||
class ComposedReward(GlobalReward): | ||
datastore_type = ComposedDataStore | ||
|
||
def pass_data(self) -> Data: | ||
for rewarder, data in zip(self.rewarders, self.data.data): | ||
rewarder.data = data | ||
|
||
def __init__(self, *rewarders: GlobalReward) -> None: | ||
"""Construct composed reward. | ||
Args: | ||
rewards: Global rewards to compose. | ||
""" | ||
super().__init__() | ||
self.rewarders = rewarders | ||
|
||
def __getattr__(self, name: str): | ||
for rewarder in self.rewarders: | ||
if hasattr(rewarder, name): | ||
return getattr(rewarder, name) | ||
raise AttributeError( | ||
f"No GlobalReward in ComposedReward has attribute '{name}'" | ||
) | ||
|
||
def reset_pre_sim_init(self) -> None: | ||
super().reset_pre_sim_init() | ||
for rewarder in self.rewarders: | ||
rewarder.reset_pre_sim_init() | ||
|
||
def reset_post_sim_init(self) -> None: | ||
super().reset_post_sim_init() | ||
for rewarder in self.rewarders: | ||
rewarder.reset_post_sim_init() | ||
|
||
def reset_during_sim_init(self) -> None: | ||
super().reset_during_sim_init() | ||
for rewarder in self.rewarders: | ||
rewarder.reset_during_sim_init() | ||
|
||
def reset_overwrite_previous(self) -> None: | ||
super().reset_overwrite_previous() | ||
for rewarder in self.rewarders: | ||
rewarder.reset_overwrite_previous() | ||
|
||
def link_scenario(self, scenario: Scenario) -> None: | ||
super().link_scenario(scenario) | ||
for rewarder in self.rewarders: | ||
rewarder.link_scenario(scenario) | ||
|
||
def initial_data(self, satellite: Satellite) -> ComposedData: | ||
return ComposedData( | ||
*[rewarder.initial_data(satellite) for rewarder in self.rewarders] | ||
) | ||
|
||
def create_data_store(self, satellite: Satellite) -> None: | ||
# TODO support passing kwargs | ||
satellite.data_store = ComposedDataStore( | ||
satellite, | ||
*[r.datastore_type for r in self.rewarders], | ||
initial_data=self.initial_data(satellite), | ||
) | ||
self.cum_reward[satellite.name] = 0.0 | ||
|
||
def calculate_reward( | ||
self, new_data_dict: dict[str, ComposedData] | ||
) -> dict[str, float]: | ||
data_len = len(list(new_data_dict.values())[0].data) | ||
|
||
for data in new_data_dict.values(): | ||
assert len(data.data) == data_len | ||
|
||
reward = {} | ||
if data_len != 0: | ||
for i, rewarder in enumerate(self.rewarders): | ||
reward_i = rewarder.calculate_reward( | ||
{sat_id: data.data[i] for sat_id, data in new_data_dict.items()} | ||
) | ||
|
||
# Logging | ||
nonzero_reward = {k: v for k, v in reward_i.items() if v != 0} | ||
if len(nonzero_reward) > 0: | ||
logger.info(f"{type(rewarder).__name__} reward: {nonzero_reward}") | ||
|
||
for sat_id, sat_reward in reward_i.items(): | ||
reward[sat_id] = reward.get(sat_id, 0.0) + sat_reward | ||
return reward | ||
|
||
def reward(self, new_data_dict: dict[str, ComposedData]) -> dict[str, float]: | ||
reward = super().reward(new_data_dict) | ||
self.pass_data() | ||
return reward |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters