diff --git a/og_marl/vault_utils/analyse_vault.py b/og_marl/vault_utils/analyse_vault.py index e47554a..6a2ccbd 100644 --- a/og_marl/vault_utils/analyse_vault.py +++ b/og_marl/vault_utils/analyse_vault.py @@ -26,13 +26,24 @@ def get_structure_descriptors( - experience: Dict[str, Array], n_head: int = 1 + experience: Dict[str, Array], n_head: int = 1, done_flags: tuple = ("terminals"), ) -> Tuple[Dict[str, Array], Dict[str, Array], int]: struct = jax.tree_map(lambda x: x.shape, experience) head = jax.tree_map(lambda x: x[0, :n_head, ...], experience) - terminal_flag = experience["terminals"][0, :, ...].all(axis=-1) + # allow for "terminals" and "truncations" to be combined into one "done" + if len(done_flags)==1: + terminal_flag = experience[done_flags[0]][0, :, ...].all(axis=-1) # .all is for all agents + elif len(done_flags)==2: + done_1 = experience[done_flags[0]][0, :, ...].all(axis=-1) + done_2 = experience[done_flags[1]][0, :, ...].all(axis=-1) + + terminal_flag = jnp.logical_or(done_1,done_2) + else: + print("Too many done flags. Please revise.") + return struct, head, 0 + num_episodes = int(jnp.sum(terminal_flag)) return struct, head, num_episodes @@ -43,6 +54,7 @@ def describe_structure( vault_uids: Optional[List[str]] = None, rel_dir: str = "vaults", n_head: int = 0, + done_flags: tuple = ("terminals"), ) -> Dict[str, Array]: # get all uids if not specified if vault_uids is None: @@ -57,7 +69,7 @@ def describe_structure( exp = vlt.read().experience n_trans = exp["actions"].shape[1] - struct, head, n_traj = get_structure_descriptors(exp, n_head) + struct, head, n_traj = get_structure_descriptors(exp, n_head, done_flags) print(str(uid) + "\n-----") for key, val in struct.items(): @@ -75,9 +87,9 @@ def describe_structure( def get_episode_return_descriptors( - experience: Dict[str, Array], + experience: Dict[str, Array], done_flags: tuple = ("terminals"), ) -> Tuple[float, float, float, float, Array]: - episode_returns = calculate_returns(experience) + episode_returns = calculate_returns(experience, done_flags = done_flags) mean = jnp.mean(episode_returns) stddev = jnp.std(episode_returns) @@ -151,6 +163,7 @@ def describe_episode_returns( save_violin: bool = False, plot_saving_rel_dir: str = "vaults", n_bins: Optional[int] = 50, + done_flags: tuple = ("terminals"), ) -> None: """Describe a vault. @@ -170,7 +183,7 @@ def describe_episode_returns( vlt = Vault(vault_name=vault_name, rel_dir=rel_dir, vault_uid=uid) exp = vlt.read().experience - mean, stddev, max_ret, min_ret, episode_returns = get_episode_return_descriptors(exp) + mean, stddev, max_ret, min_ret, episode_returns = get_episode_return_descriptors(exp, done_flags) all_uid_eps_returns[uid] = episode_returns single_values.append([uid, mean, stddev, max_ret, min_ret]) @@ -209,7 +222,7 @@ def describe_episode_returns( def calculate_returns( - experience: Dict[str, Array], reward_key: str = "rewards", terminal_key: str = "terminals" + experience: Dict[str, Array], reward_key: str = "rewards", done_flags: tuple = ("terminals"), ) -> Array: """Calculate the returns in a dataset of experience. @@ -229,7 +242,13 @@ def calculate_returns( # We want all the time data, but just from one agent experience_one_agent = jax.tree_map(lambda x: x[0, :, 0, ...], experience) rewards = experience_one_agent[reward_key] - terminals = jnp.array(experience[terminal_key][0].all(axis=-1).squeeze(), dtype=jnp.float32) + + if len(done_flags)==1: + terminals = jnp.array(experience[done_flags[0]][0].all(axis=-1).squeeze(), dtype=jnp.float32) # .all is for all agents + elif len(done_flags)==2: + done_1 = jnp.array(experience[done_flags[0]][0].all(axis=-1).squeeze(), dtype=jnp.float32) + done_2 = jnp.array(experience[done_flags[1]][0].all(axis=-1).squeeze(), dtype=jnp.float32) + terminals = jnp.logical_or(done_1,done_2).astype(jnp.float32) def sum_rewards(terminals: Array, rewards: Array) -> Array: def scan_fn(carry: Array, inputs: Array) -> Array: @@ -390,6 +409,7 @@ def descriptive_summary( plot_hist: bool = True, save_hist: bool = False, n_bins: int = 40, + done_flags: tuple = ("terminals"), ) -> Dict[str, Array]: """Provides coverage, structural and episode return descriptors of a Vault of datasets. @@ -430,7 +450,7 @@ def descriptive_summary( exp = vlt.read().experience saco, _, _ = get_saco(exp) - mean, stddev, max_ret, min_ret, episode_returns = get_episode_return_descriptors(exp) + mean, stddev, max_ret, min_ret, episode_returns = get_episode_return_descriptors(exp, done_flags) n_traj = len(episode_returns) n_trans = exp["actions"].shape[1]