Skip to content

Commit

Permalink
feat: log win rate in case it exists
Browse files Browse the repository at this point in the history
  • Loading branch information
OmaymaMahjoub committed Dec 18, 2023
1 parent bb82393 commit 7c30af9
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
10 changes: 9 additions & 1 deletion mava/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __call__(
def get_logger_tools(logger: Logger) -> LogFn: # noqa: CCR001
"""Get the logger function."""

def log(
def log( # noqa: CCR001
metrics: ExperimentOutput,
t_env: int = 0,
trainer_metric: bool = False,
Expand Down Expand Up @@ -69,6 +69,10 @@ def log(
prefix = "evaluator/"
episodes_info = metrics.episodes_info

# Add win rate to episodes_info in case it exists.
if "won_episode" in episodes_info:
win_rate = (jnp.sum(episodes_info["won_episode"]) / logger.num_eval_episodes) * 100

# Flatten metrics info.
episodes_return = jnp.ravel(episodes_info["episode_return"])
episodes_length = jnp.ravel(episodes_info["episode_length"])
Expand All @@ -89,6 +93,8 @@ def log(
logger.log_stat(f"{prefix}value_loss", float(np.mean(value_loss)), t_env)
logger.log_stat(f"{prefix}loss_actor", float(np.mean(loss_actor)), t_env)
logger.log_stat(f"{prefix}entropy", float(np.mean(entropy)), t_env)
if "won_episode" in episodes_info:
logger.log_stat(f"{prefix}win_rate", float(win_rate), t_env, eval_step)

log_string = (
f"Timesteps {t_env:07d} | "
Expand All @@ -100,6 +106,8 @@ def log(
f"Max Episode Length {float(np.max(episodes_length)):.3f} | "
f"Steps Per Second {steps_per_second:.2e} "
)
if "won_episode" in episodes_info:
log_string += f"| Win Rate {win_rate:.2f}%"

if absolute_metric:
logger.console_logger.info(
Expand Down
17 changes: 13 additions & 4 deletions mava/utils/logger_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self, cfg: Dict) -> None:
self.should_log = bool(
cfg["logger"]["use_json"] or cfg["logger"]["use_tf"] or cfg["logger"]["use_neptune"]
)
self.num_eval_episodes = cfg["arch"]["num_eval_episodes"]

def _setup_tb(self, cfg: Dict) -> None:
"""Set up tensorboard logging."""
Expand All @@ -77,11 +78,15 @@ def _setup_json(self, cfg: Dict) -> None:
json_logs_path = os.path.join(
cfg["logger"]["base_exp_path"], "json", cfg["logger"]["kwargs"]["json_path"]
)

task_name = (
cfg["env"]["scenario"]
if isinstance(cfg["env"]["scenario"], str)
else cfg["env"]["scenario"]["task_name"]
)
self.json_logger = JsonWriter(
path=json_logs_path,
algorithm_name=cfg["logger"]["system_name"],
task_name=cfg["env"]["scenario"]["task_name"],
task_name=task_name,
environment_name=cfg["env"]["env_name"],
seed=cfg["system"]["seed"],
)
Expand Down Expand Up @@ -140,10 +145,14 @@ def get_neptune_logger(cfg: Dict) -> neptune.Run:

def get_experiment_path(config: Dict, logger_type: str) -> str:
"""Helper function to create the experiment path."""
task_name = (
config["env"]["scenario"]
if isinstance(config["env"]["scenario"], str)
else config["env"]["scenario"]["task_name"]
)
exp_path = (
f"{logger_type}/{config['logger']['system_name']}/{config['env']['env_name']}/"
+ f"{config['env']['scenario']['task_name']}"
+ f"/envs_{config['arch']['num_envs']}/seed_{config['system']['seed']}"
+ f"{task_name}/envs_{config['arch']['num_envs']}/seed_{config['system']['seed']}"
)

return exp_path
Expand Down

0 comments on commit 7c30af9

Please sign in to comment.