From 70b3547c65c0a555787173137b201dcd7aae63d3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <48008469+puyuan1996@users.noreply.github.com>
Date: Sat, 9 Nov 2024 21:56:09 +0800
Subject: [PATCH] feature(pu): add atari100k metric utils (#295)

* feature(pu): add compute_normalized_mean_and_median_atari100k.py

* feature(pu): add read_atari_results_from_txt.py

* polish(pu): polish comments and docstrings in atari100k utils
---
 ...te_normalized_mean_and_median_atari100k.py | 318 ++++++++++++++++++
 zoo/atari/read_atari_results_from_txt.py      | 162 +++++++++
 2 files changed, 480 insertions(+)
 create mode 100644 zoo/atari/compute_normalized_mean_and_median_atari100k.py
 create mode 100644 zoo/atari/read_atari_results_from_txt.py

diff --git a/zoo/atari/compute_normalized_mean_and_median_atari100k.py b/zoo/atari/compute_normalized_mean_and_median_atari100k.py
new file mode 100644
index 000000000..61247a7ca
--- /dev/null
+++ b/zoo/atari/compute_normalized_mean_and_median_atari100k.py
@@ -0,0 +1,318 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+
+def compute_normalized_mean_and_median(
+    random_scores: list[float],
+    human_scores: list[float],
+    algo_scores: list[float]
+) -> tuple[float, float]:
+    """
+    Computes the human-normalized mean and median of algorithm scores, using per-game random and human baselines.
+
+    Args:
+        random_scores (list[float]): List of random scores for each game.
+        human_scores (list[float]): List of human scores for each game.
+        algo_scores (list[float]): List of algorithm scores for each game.
+
+    Returns:
+        tuple[float, float]:
+            - The mean of the normalized scores.
+            - The median of the normalized scores.
+
+    Raises:
+        ValueError: If any list is empty or if the lengths of the input lists do not match.
+    """
+    if not random_scores or not human_scores or not algo_scores:
+        raise ValueError("Input score lists must not be empty.")
+    if len(random_scores) != len(human_scores) or len(human_scores) != len(algo_scores):
+        raise ValueError("Input score lists must have the same length.")
+
+    # Normalize each game as (algo - random) / (human - random); fall back to 0
+    # when the human and random baselines coincide, to avoid division by zero.
+    normalized_scores = [
+        (algo_score - random_score) / (human_score - random_score)
+        if human_score != random_score else 0
+        for random_score, human_score, algo_score in zip(random_scores, human_scores, algo_scores)
+    ]
+
+    # Compute the mean and median of the normalized scores
+    normalized_mean = np.mean(normalized_scores)
+    normalized_median = np.median(normalized_scores)
+
+    return normalized_mean, normalized_median
+
+
+def plot_normalized_scores(
+    algorithms: list[str],
+    means: list[float],
+    medians: list[float],
+    filename: str = "normalized_scores.png"
+) -> None:
+    """
+    Plots a bar chart of the normalized mean and median values for different algorithms.
+
+    Args:
+        algorithms (list[str]): List of algorithm names.
+        means (list[float]): List of normalized mean values.
+        medians (list[float]): List of normalized median values.
+        filename (str, optional): Filename to save the plot (default is 'normalized_scores.png').
+
+    Returns:
+        None
+
+    Raises:
+        ValueError: If the lists of algorithms, means, or medians have different lengths.
+
+    Example usage:
+        algorithms = ["Algorithm A", "Algorithm B", "Algorithm C"]
+        means = [0.75, 0.85, 0.60]
+        medians = [0.70, 0.80, 0.65]
+        plot_normalized_scores(algorithms, means, medians)
+    """
+    if not (len(algorithms) == len(means) == len(medians)):
+        raise ValueError("Algorithms, means, and medians lists must have the same length.")
+
+    # Set a style suited for academic papers (muted, professional colors)
+    sns.set_theme(style="whitegrid")
+
+    x = np.arange(len(algorithms))  # The label locations
+    width = 0.35  # Width of the bars
+
+    # Set up the figure with a larger size (good for academic papers)
+    fig, ax = plt.subplots(figsize=(10, 6))
+
+    # Define the color palette for consistency and readability
+    mean_color = sns.color_palette("muted")[0]  # Muted blue
+    median_color = sns.color_palette("muted")[1]  # Muted orange
+
+    # Plot the bars for mean and median
+    bars_mean = ax.bar(x - width / 2, means, width, label='Normalized Mean', color=mean_color)
+    bars_median = ax.bar(x + width / 2, medians, width, label='Normalized Median', color=median_color)
+
+    # Add labels, title, and custom x-axis tick labels
+    ax.set_ylabel('Scores', fontsize=14)
+    ax.set_title('Human Normalized Score (Atari 100k)', fontsize=16, pad=20)
+    ax.set_xticks(x)
+    ax.set_xticklabels(algorithms, fontsize=12)
+    ax.legend(fontsize=12)
+
+    # Add grid lines for better readability
+    ax.yaxis.grid(True, linestyle='--', alpha=0.7)
+
+    # Attach a text label above each bar displaying its height
+    def attach_labels(bars):
+        for bar in bars:
+            height = bar.get_height()
+            # Annotate with a precision of two decimal places
+            ax.annotate(f'{height:.2f}',  # Text to display
+                        xy=(bar.get_x() + bar.get_width() / 2, height),  # Position
+                        xytext=(0, 3),  # Offset from the top of the bar
+                        textcoords="offset points",
+                        ha='center', va='bottom', fontsize=11)
+
+    attach_labels(bars_mean)
+    attach_labels(bars_median)
+
+    # Adjust the layout for a tight fit (avoids cutting off labels)
+    fig.tight_layout()
+
+    # Save the plot as a high-resolution image suitable for publication
+    plt.savefig(filename, dpi=300)
+    print(f"Plot saved as {filename}")
+    plt.close()
+
+
+# Per-game scores on the 26-game Atari 100k benchmark
+random_scores = [
+    227.8,  # Alien
+    5.8,  # Amidar
+    222.4,  # Assault
+    210.0,  # Asterix
+    14.2,  # BankHeist
+    2360.0,  # BattleZone
+    0.1,  # Boxing
+    1.7,  # Breakout
+    811.0,  # ChopperCommand
+    10780.5,  # CrazyClimber
+    152.1,  # DemonAttack
+    0.0,  # Freeway
+    65.2,  # Frostbite
+    257.6,  # Gopher
+    1027.0,  # Hero
+    29.0,  # Jamesbond
+    52.0,  # Kangaroo
+    1598.0,  # Krull
+    258.5,  # KungFuMaster
+    307.3,  # MsPacman
+    -20.7,  # Pong
+    24.9,  # PrivateEye
+    163.9,  # Qbert
+    11.5,  # RoadRunner
+    68.4,  # Seaquest
+    533.4  # UpNDown
+]
+
+human_scores = [
+    7127.7,  # Alien
+    1719.5,  # Amidar
+    742.0,  # Assault
+    8503.3,  # Asterix
+    753.1,  # BankHeist
+    37187.5,  # BattleZone
+    12.1,  # Boxing
+    30.5,  # Breakout
+    7387.8,  # ChopperCommand
+    35829.4,  # CrazyClimber
+    1971.0,  # DemonAttack
+    29.6,  # Freeway
+    4334.7,  # Frostbite
+    2412.5,  # Gopher
+    30826.4,  # Hero
+    302.8,  # Jamesbond
+    3035.0,  # Kangaroo
+    2665.5,  # Krull
+    22736.3,  # KungFuMaster
+    6951.6,  # MsPacman
+    14.6,  # Pong
+    69571.3,  # PrivateEye
+    13455.0,  # Qbert
+    7845.0,  # RoadRunner
+    42054.7,  # Seaquest
+    11693.2  # UpNDown
+]
+
+# EfficientZero (EZ) scores
+ez_scores = [
+    808.5,  # Alien
+    149,  # Amidar
+    1263,  # Assault
+    25558,  # Asterix
+    351,  # BankHeist
+    13871,  # BattleZone
+    53,  # Boxing
+    414,  # Breakout
+    1117,  # ChopperCommand
+    83940,  # CrazyClimber
+    13004,  # DemonAttack
+    22,  # Freeway
+    296,  # Frostbite
+    3260,  # Gopher
+    9315,  # Hero
+    517,  # Jamesbond
+    724,  # Kangaroo
+    5663,  # Krull
+    30945,  # KungFuMaster
+    1281,  # MsPacman
+    20,  # Pong
+    97,  # PrivateEye
+    13782,  # Qbert
+    17751,  # RoadRunner
+    1100,  # Seaquest
+    17264  # UpNDown
+]
+
+# MuZero (MZ) scores
+mz_scores = [
+    530.0,  # Alien
+    39,  # Amidar
+    500,  # Assault
+    1734,  # Asterix
+    193,  # BankHeist
+    7688,  # BattleZone
+    15,  # Boxing
+    48,  # Breakout
+    1350,  # ChopperCommand
+    56937,  # CrazyClimber
+    3527,  # DemonAttack
+    22,  # Freeway
+    255,  # Frostbite
+    1256,  # Gopher
+    3095,  # Hero
+    88,  # Jamesbond
+    63,  # Kangaroo
+    4891,  # Krull
+    18813,  # KungFuMaster
+    1266,  # MsPacman
+    -7,  # Pong
+    56,  # PrivateEye
+    3952,  # Qbert
+    2500,  # RoadRunner
+    208,  # Seaquest
+    2897  # UpNDown
+]
+
+# MuZero with self-supervised consistency loss (MZ with SSL) scores
+mz_ssl_scores = [
+    700,  # Alien
+    90,  # Amidar
+    600,  # Assault
+    1400,  # Asterix
+    33,  # BankHeist
+    7587,  # BattleZone
+    20,  # Boxing
+    4,  # Breakout
+    2050,  # ChopperCommand
+    26060,  # CrazyClimber
+    4601,  # DemonAttack
+    12,  # Freeway
+    260,  # Frostbite
+    646,  # Gopher
+    9315,  # Hero
+    300,  # Jamesbond
+    600,  # Kangaroo
+    2700,  # Krull
+    25100,  # KungFuMaster
+    1410,  # MsPacman
+    -15,  # Pong
+    100,  # PrivateEye
+    4700,  # Qbert
+    3400,  # RoadRunner
+    566,  # Seaquest
+    5213  # UpNDown
+]
+
+# UniZero scores
+unizero_scores = [
+    1000,  # Alien
+    96,  # Amidar
+    609,  # Assault
+    1016,  # Asterix
+    50,  # BankHeist
+    11410,  # BattleZone
+    7,  # Boxing
+    12,  # Breakout
+    3205,  # ChopperCommand
+    13666,  # CrazyClimber
+    1001,  # DemonAttack
+    7,  # Freeway
+    310,  # Frostbite
+    1153,  # Gopher
+    8005,  # Hero
+    305,  # Jamesbond
+    1285,  # Kangaroo
+    3484,  # Krull
+    15600,  # KungFuMaster
+    1927,  # MsPacman
+    18,  # Pong
+    1048,  # PrivateEye
+    3056,  # Qbert
+    11000,  # RoadRunner
+    620,  # Seaquest
+    4523  # UpNDown
+]
+
+
+if __name__ == "__main__":
+    # Calculate the normalized mean and median for each algorithm
+    ez_mean, ez_median = compute_normalized_mean_and_median(random_scores, human_scores, ez_scores)
+    mz_mean, mz_median = compute_normalized_mean_and_median(random_scores, human_scores, mz_scores)
+    mz_ssl_mean, mz_ssl_median = compute_normalized_mean_and_median(random_scores, human_scores, mz_ssl_scores)
+    unizero_mean, unizero_median = compute_normalized_mean_and_median(random_scores, human_scores, unizero_scores)
+
+    # Print the results
+    print(f"EZ - Normalized Mean: {ez_mean}, Normalized Median: {ez_median}")
+    print(f"MZ - Normalized Mean: {mz_mean}, Normalized Median: {mz_median}")
+    print(f"MZ with SSL - Normalized Mean: {mz_ssl_mean}, Normalized Median: {mz_ssl_median}")
+    print(f"UniZero - Normalized Mean: {unizero_mean}, Normalized Median: {unizero_median}")
+
+    # Plot the normalized means and medians for MZ, MZ with SSL, and UniZero
+    # (EZ is printed above for reference but excluded from this 3-algorithm plot)
+    algorithms = ['MZ', 'MZ with SSL', 'UniZero']
+    means = [mz_mean, mz_ssl_mean, unizero_mean]
+    medians = [mz_median, mz_ssl_median, unizero_median]
+
+    # Save the plot as a PNG file
+    plot_normalized_scores(algorithms, means, medians, filename="atari100k_normalized_scores_3algo.png")
\ No newline at end of file
diff --git a/zoo/atari/read_atari_results_from_txt.py b/zoo/atari/read_atari_results_from_txt.py
new file mode 100644
index 000000000..74788ae90
--- /dev/null
+++ b/zoo/atari/read_atari_results_from_txt.py
@@ -0,0 +1,162 @@
+import os
+
+
+def parse_eval_return_mean(log_file_path: str) -> float | None:
+    """
+    Parse the eval_episode_return_mean from the evaluator log file, reading from the end of the file.
+
+    Args:
+        log_file_path (str): The path to the evaluator log file.
+
+    Returns:
+        float | None: The eval_episode_return_mean as a float, or None if not found.
+ """ + try: + with open(log_file_path, 'r') as file: + lines = file.readlines() + + # Reverse the lines to start reading from the end of the file + for i, line in enumerate(reversed(lines)): + if 'eval_episode_return_mean' in line: + prev_line = lines[-i + 1] # This gets the 'next' line (2 lines after in normal order) + parts = prev_line.split('|') + if len(parts) >= 4: + try: + return float(parts[3].strip()) + except ValueError: + print(f"[ERROR] Failed to convert {parts[3].strip()} to float in {log_file_path}.") + return None + except FileNotFoundError: + print(f"[ERROR] Log file {log_file_path} not found.") + except Exception as e: + print(f"[ERROR] An error occurred while reading {log_file_path}: {str(e)}") + + return None + + +def find_evaluator_log_files(base_path: str) -> list[str]: + """ + Find all evaluator_logger.txt files within the specified base directory. + + Args: + base_path (str): The base directory to start searching from. + + Returns: + list[str]: A list of paths to evaluator_logger.txt files. + """ + evaluator_log_paths = [] + + # Walk through the base directory recursively + for root, dirs, files in os.walk(base_path): + # Check if the current folder contains evaluator_logger.txt + if 'evaluator_logger.txt' in files: + # Construct the full path to the evaluator_logger.txt file + log_file_path = os.path.join(root, 'evaluator_logger.txt') + evaluator_log_paths.append(log_file_path) + + return evaluator_log_paths + + +def extract_game_name_from_path(log_file_path: str) -> str | None: + """ + Extract the game name from the log file path. + The game name is assumed to be the part of the path just before '_atari'. + + Args: + log_file_path (str): The path to the evaluator log file. + + Returns: + str | None: The extracted game name, or None if extraction fails. + """ + try: + parts = log_file_path.split('/') + for part in parts: + if '_atari' in part: + game_name = part.split('_atari')[0] + return game_name + except Exception as e: + print(f"[ERROR] Couldn't extract game name from {log_file_path}: {str(e)}") + return None + + +def get_eval_means_for_games(base_path: str) -> tuple[list[str], list[float | None], dict[str, float | None]]: + """ + Get the eval_episode_return_mean for all games under the base directory, along with the game names. + + Args: + base_path (str): The path to the base directory containing game logs. + + Returns: + tuple[list[str], list[float | None], dict[str, float | None]]: + - List of game names. + - List of eval_episode_return_mean values (None if not found). + - Dictionary mapping game names to eval_episode_return_mean values. + """ + game_names = [] + eval_means = [] + game_eval_dict = {} + + # Find all evaluator_logger.txt files + log_files = find_evaluator_log_files(base_path) + + # Parse each log file for eval_episode_return_mean and extract the game name + for log_file in log_files: + eval_mean = parse_eval_return_mean(log_file) + game_name = extract_game_name_from_path(log_file) + + if game_name is not None: + game_names.append(game_name) + eval_means.append(eval_mean) + game_eval_dict[game_name] = eval_mean + + return game_names, eval_means, game_eval_dict + + +def save_results_to_file(game_eval_dict: dict[str, float | None], file_path: str) -> None: + """ + Save the game names and eval means to a text file. + + Args: + game_eval_dict (dict[str, float | None]): Dictionary of game names and corresponding eval means. + file_path (str): The path to the output text file. 
+
+    Returns:
+        None
+    """
+    try:
+        with open(file_path, 'w') as file:
+            file.write("Game Names and Eval Episode Return Means:\n")
+            file.write("=" * 50 + "\n")
+            for game_name, eval_mean in game_eval_dict.items():
+                file.write(f"Game: {game_name}, Eval Episode Return Mean: {eval_mean}\n")
+        print(f"[INFO] Results saved to {file_path}.")
+    except Exception as e:
+        print(f"[ERROR] Failed to save the file: {str(e)}")
+
+
+if __name__ == "__main__":
+    # Change base_path to the directory where your experiment data is stored, and
+    # run this script from a directory such as /code/LightZero/zoo/atari.
+    base_path = "./config/data_muzero"
+    game_names, eval_means, game_eval_dict = get_eval_means_for_games(base_path)
+
+    # Display the lists and the dictionary in a readable way
+    print("Game Names List:")
+    print("=" * 50)
+    for game_name in game_names:
+        print(game_name)
+
+    print("\nEval Episode Return Means List:")
+    print("=" * 50)
+    for eval_mean in eval_means:
+        print(eval_mean)
+
+    print("\nCombined Dictionary (game_name -> eval_episode_return_mean):")
+    print("=" * 50)
+    for game_name, eval_mean in game_eval_dict.items():
+        print(f"{game_name}: {eval_mean}")
+
+    # Optionally save the results to a text file
+    save_option = input("\nWould you like to save the results to a text file? (y/n): ").strip().lower()
+    if save_option == 'y':
+        file_path = input("Enter the desired output file path (e.g., results.txt): ").strip()
+        save_results_to_file(game_eval_dict, file_path)
\ No newline at end of file
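A quick way to sanity-check the two new utilities without a full training run is a small smoke test along the lines of the sketch below. This is an illustrative sketch, not part of the patch: it assumes the patch has been applied and that it is run from zoo/atari (so both new modules are importable and the first module's dependencies, numpy/matplotlib/seaborn, are installed), and the toy log snippet merely mimics the pipe-delimited evaluator table that parse_eval_return_mean expects (value row two lines below the matched header, score in the fourth '|'-separated field); the real DI-engine log layout may differ.

# Hypothetical smoke test for the two utilities added in this patch (not part
# of the patch itself); run from zoo/atari after applying it. The toy log
# below is an assumed example of the evaluator table layout, not a format
# guaranteed by DI-engine or this patch.
import tempfile

from compute_normalized_mean_and_median_atari100k import compute_normalized_mean_and_median
from read_atari_results_from_txt import parse_eval_return_mean

# Two toy games: each normalizes to (algo - random) / (human - random) = 0.5,
# so both the mean and the median of the normalized scores should be 0.5.
random_scores = [0.0, 10.0]
human_scores = [100.0, 110.0]
algo_scores = [50.0, 60.0]
mean, median = compute_normalized_mean_and_median(random_scores, human_scores, algo_scores)
assert abs(mean - 0.5) < 1e-9 and abs(median - 0.5) < 1e-9

# Toy evaluator log: the value row sits two lines below the header row and the
# score is the fourth '|'-separated field, matching what the parser expects.
log_text = (
    "| Name  | train_iter | eval_episode_return_mean |\n"
    "+-------+------------+--------------------------+\n"
    "| Value | 2000       | 350.0                    |\n"
)
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    f.write(log_text)
    log_path = f.name
assert parse_eval_return_mean(log_path) == 350.0
print("smoke test passed")

Expected output is a single "smoke test passed" line; the normalization check follows directly from (algo - random) / (human - random), which equals 0.5 for both toy games here.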