Skip to content

Commit

Permalink
Merge branch 'main' of github.com:instadeepai/og-marl
Browse files Browse the repository at this point in the history
  • Loading branch information
lbeyers committed Sep 11, 2024
2 parents df53d97 + 6d0210d commit 506fc87
Showing 1 changed file with 29 additions and 9 deletions.
38 changes: 29 additions & 9 deletions og_marl/vault_utils/analyse_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,24 @@


def get_structure_descriptors(
experience: Dict[str, Array], n_head: int = 1
experience: Dict[str, Array], n_head: int = 1, done_flags: tuple = ("terminals"),
) -> Tuple[Dict[str, Array], Dict[str, Array], int]:
struct = jax.tree_map(lambda x: x.shape, experience)

head = jax.tree_map(lambda x: x[0, :n_head, ...], experience)

terminal_flag = experience["terminals"][0, :, ...].all(axis=-1)
# allow for "terminals" and "truncations" to be combined into one "done"
if len(done_flags)==1:
terminal_flag = experience[done_flags[0]][0, :, ...].all(axis=-1) # .all is for all agents
elif len(done_flags)==2:
done_1 = experience[done_flags[0]][0, :, ...].all(axis=-1)
done_2 = experience[done_flags[1]][0, :, ...].all(axis=-1)

terminal_flag = jnp.logical_or(done_1,done_2)
else:
print("Too many done flags. Please revise.")
return struct, head, 0

num_episodes = int(jnp.sum(terminal_flag))

return struct, head, num_episodes
Expand All @@ -43,6 +54,7 @@ def describe_structure(
vault_uids: Optional[List[str]] = None,
rel_dir: str = "vaults",
n_head: int = 0,
done_flags: tuple = ("terminals"),
) -> Dict[str, Array]:
# get all uids if not specified
if vault_uids is None:
Expand All @@ -57,7 +69,7 @@ def describe_structure(
exp = vlt.read().experience
n_trans = exp["actions"].shape[1]

struct, head, n_traj = get_structure_descriptors(exp, n_head)
struct, head, n_traj = get_structure_descriptors(exp, n_head, done_flags)

print(str(uid) + "\n-----")
for key, val in struct.items():
Expand All @@ -75,9 +87,9 @@ def describe_structure(


def get_episode_return_descriptors(
experience: Dict[str, Array],
experience: Dict[str, Array], done_flags: tuple = ("terminals"),
) -> Tuple[float, float, float, float, Array]:
episode_returns = calculate_returns(experience)
episode_returns = calculate_returns(experience, done_flags = done_flags)

mean = jnp.mean(episode_returns)
stddev = jnp.std(episode_returns)
Expand Down Expand Up @@ -151,6 +163,7 @@ def describe_episode_returns(
save_violin: bool = False,
plot_saving_rel_dir: str = "vaults",
n_bins: Optional[int] = 50,
done_flags: tuple = ("terminals"),
) -> None:
"""Describe a vault.
Expand All @@ -170,7 +183,7 @@ def describe_episode_returns(
vlt = Vault(vault_name=vault_name, rel_dir=rel_dir, vault_uid=uid)
exp = vlt.read().experience

mean, stddev, max_ret, min_ret, episode_returns = get_episode_return_descriptors(exp)
mean, stddev, max_ret, min_ret, episode_returns = get_episode_return_descriptors(exp, done_flags)
all_uid_eps_returns[uid] = episode_returns

single_values.append([uid, mean, stddev, max_ret, min_ret])
Expand Down Expand Up @@ -209,7 +222,7 @@ def describe_episode_returns(


def calculate_returns(
experience: Dict[str, Array], reward_key: str = "rewards", terminal_key: str = "terminals"
experience: Dict[str, Array], reward_key: str = "rewards", done_flags: tuple = ("terminals"),
) -> Array:
"""Calculate the returns in a dataset of experience.
Expand All @@ -229,7 +242,13 @@ def calculate_returns(
# We want all the time data, but just from one agent
experience_one_agent = jax.tree_map(lambda x: x[0, :, 0, ...], experience)
rewards = experience_one_agent[reward_key]
terminals = jnp.array(experience[terminal_key][0].all(axis=-1).squeeze(), dtype=jnp.float32)

if len(done_flags)==1:
terminals = jnp.array(experience[done_flags[0]][0].all(axis=-1).squeeze(), dtype=jnp.float32) # .all is for all agents
elif len(done_flags)==2:
done_1 = jnp.array(experience[done_flags[0]][0].all(axis=-1).squeeze(), dtype=jnp.float32)
done_2 = jnp.array(experience[done_flags[1]][0].all(axis=-1).squeeze(), dtype=jnp.float32)
terminals = jnp.logical_or(done_1,done_2).astype(jnp.float32)

def sum_rewards(terminals: Array, rewards: Array) -> Array:
def scan_fn(carry: Array, inputs: Array) -> Array:
Expand Down Expand Up @@ -390,6 +409,7 @@ def descriptive_summary(
plot_hist: bool = True,
save_hist: bool = False,
n_bins: int = 40,
done_flags: tuple = ("terminals"),
) -> Dict[str, Array]:
"""Provides coverage, structural and episode return descriptors of a Vault of datasets.
Expand Down Expand Up @@ -430,7 +450,7 @@ def descriptive_summary(
exp = vlt.read().experience

saco, _, _ = get_saco(exp)
mean, stddev, max_ret, min_ret, episode_returns = get_episode_return_descriptors(exp)
mean, stddev, max_ret, min_ret, episode_returns = get_episode_return_descriptors(exp, done_flags)
n_traj = len(episode_returns)
n_trans = exp["actions"].shape[1]

Expand Down

0 comments on commit 506fc87

Please sign in to comment.