Skip to content

Commit

Permalink
Short-term fix for len() treating a one-value tuple written without a trailing comma, e.g. ("terminals"), as len(value).
Browse files Browse the repository at this point in the history
  • Loading branch information
lbeyers committed Sep 12, 2024
1 parent 506fc87 commit 6f11104
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
12 changes: 6 additions & 6 deletions og_marl/vault_utils/analyse_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


def get_structure_descriptors(
experience: Dict[str, Array], n_head: int = 1, done_flags: tuple = ("terminals"),
experience: Dict[str, Array], n_head: int = 1, done_flags: tuple = ("terminals",),
) -> Tuple[Dict[str, Array], Dict[str, Array], int]:
struct = jax.tree_map(lambda x: x.shape, experience)

Expand Down Expand Up @@ -54,7 +54,7 @@ def describe_structure(
vault_uids: Optional[List[str]] = None,
rel_dir: str = "vaults",
n_head: int = 0,
done_flags: tuple = ("terminals"),
done_flags: tuple = ("terminals",),
) -> Dict[str, Array]:
# get all uids if not specified
if vault_uids is None:
Expand Down Expand Up @@ -87,7 +87,7 @@ def describe_structure(


def get_episode_return_descriptors(
experience: Dict[str, Array], done_flags: tuple = ("terminals"),
experience: Dict[str, Array], done_flags: tuple = ("terminals",),
) -> Tuple[float, float, float, float, Array]:
episode_returns = calculate_returns(experience, done_flags = done_flags)

Expand Down Expand Up @@ -163,7 +163,7 @@ def describe_episode_returns(
save_violin: bool = False,
plot_saving_rel_dir: str = "vaults",
n_bins: Optional[int] = 50,
done_flags: tuple = ("terminals"),
done_flags: tuple = ("terminals",),
) -> None:
"""Describe a vault.
Expand Down Expand Up @@ -222,7 +222,7 @@ def describe_episode_returns(


def calculate_returns(
experience: Dict[str, Array], reward_key: str = "rewards", done_flags: tuple = ("terminals"),
experience: Dict[str, Array], reward_key: str = "rewards", done_flags: tuple = ("terminals",),
) -> Array:
"""Calculate the returns in a dataset of experience.
Expand Down Expand Up @@ -409,7 +409,7 @@ def descriptive_summary(
plot_hist: bool = True,
save_hist: bool = False,
n_bins: int = 40,
done_flags: tuple = ("terminals"),
done_flags: tuple = ("terminals",),
) -> Dict[str, Array]:
"""Provides coverage, structural and episode return descriptors of a Vault of datasets.
Expand Down
6 changes: 3 additions & 3 deletions og_marl/vault_utils/subsample_similar.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


# cumulative summing per-episode
def get_episode_returns_and_term_idxes(offline_data: Dict[str, Array], done_flags: tuple = ("terminals")) -> Tuple[Array, Array]:
def get_episode_returns_and_term_idxes(offline_data: Dict[str, Array], done_flags: tuple = ("terminals",)) -> Tuple[Array, Array]:
"""Gets the episode returns and final indices from a batch of experience.
From a batch of experience extract the indices
Expand All @@ -41,7 +41,7 @@ def get_episode_returns_and_term_idxes(offline_data: Dict[str, Array], done_flag

terminal_flag = jnp.logical_or(done_1,done_2)

assert bool(terminal_flag[-1]) is True
# assert bool(terminal_flag[-1]) is True

def scan_cumsum(
return_so_far: float, prev_term_reward: Tuple[bool, float]
Expand Down Expand Up @@ -121,7 +121,7 @@ def subsample_similar(
second_vault_info: Dict[str, str],
new_rel_dir: str,
new_vault_name: str,
done_flags: tuple = ("terminals"),
done_flags: tuple = ("terminals",),
) -> None:
"""Subsamples 2 datasets s.t. the new datasets have similar episode return distributions."""
# check that a subsampled vault by the same name does not already exist
Expand Down
13 changes: 10 additions & 3 deletions og_marl/vault_utils/subsample_smaller.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Optional

import jax
import jax.numpy as jnp
import pickle
import numpy as np
import flashbax as fbx
Expand All @@ -30,15 +31,21 @@
# subsample vault smaller


def get_length_start_end(experience: Dict[str, Array], terminal_key: str = "terminals") -> Array:
def get_length_start_end(experience: Dict[str, Array], done_flags: tuple = ("terminals",)) -> Array:
"""Process experience to get the length, start and end of all episodes.
From a block of experience, extracts the length, start position and end position of each
episode. Length is stored for the convenience of a cumsum in the following function, and
to match other episode information blocks which store return, instead.
"""
# extract terminals
terminal_flag = experience[terminal_key][0, :, ...].all(axis=-1)
# extract episode ends, could be term or trunc
if len(done_flags)==1:
terminal_flag = experience[done_flags[0]][0, :, ...].all(axis=-1)
elif len(done_flags)==2:
done_1 = experience[done_flags[0]][0, :, ...].all(axis=-1)
done_2 = experience[done_flags[1]][0, :, ...].all(axis=-1)

terminal_flag = jnp.logical_or(done_1,done_2)

# list of indices of terminal entries
term_idxes = np.argwhere(terminal_flag)
Expand Down

0 comments on commit 6f11104

Please sign in to comment.