feature(pu): add eval_benchmark test (#296)
* feature(pu): add eval_benchmark test

* polish(pu): polish comments and docstring in eval_benchmark

* polish(pu): polish eval_benchmark
puyuan1996 authored Nov 9, 2024
1 parent 70b3547 commit d27f29a
Showing 15 changed files with 280 additions and 14 deletions.
@@ -165,7 +165,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
                 print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!')
                 self._save_replay_count += 1
         obs = to_ndarray(obs)
-        rew = to_ndarray([rew])
+        rew = to_ndarray(rew)
         return BaseEnvTimestep(obs, rew, done, info)

     def __repr__(self) -> str:
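Aside: the recurring change in this commit, to_ndarray([rew]) -> to_ndarray(rew), alters the shape of the reward returned by step(). A minimal sketch of the difference, assuming to_ndarray behaves like numpy.asarray for these inputs:

    import numpy as np

    rew = 1.0
    print(np.asarray([rew]).shape)  # (1,) -- the old call wrapped the scalar into a one-element array
    print(np.asarray(rew).shape)    # ()   -- the new call produces a 0-dim scalar array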
2 changes: 1 addition & 1 deletion zoo/box2d/bipedalwalker/envs/bipedalwalker_env.py
@@ -161,7 +161,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
                 print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!')
                 self._save_replay_count += 1
         obs = to_ndarray(obs)
-        rew = to_ndarray([rew])  # wrapped to be transferred to a array with shape (1,)
+        rew = to_ndarray(rew)  # wrapped to be transferred to an array with shape (1,)
         return BaseEnvTimestep(obs, rew, done, info)

     @property
2 changes: 1 addition & 1 deletion zoo/box2d/lunarlander/envs/lunarlander_cont_disc_env.py
@@ -169,7 +169,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
                 print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!')
                 self._save_replay_count += 1
         obs = to_ndarray(obs)
-        rew = to_ndarray([rew]).astype(np.float32)  # wrapped to be transferred to an array with shape (1,)
+        rew = to_ndarray(rew).astype(np.float32)  # wrapped to be transferred to an array with shape (1,)
         return BaseEnvTimestep(obs, rew, done, info)

     def __repr__(self) -> str:
2 changes: 1 addition & 1 deletion zoo/box2d/lunarlander/envs/lunarlander_env.py
@@ -160,7 +160,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
                 print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!')
                 self._save_replay_count += 1
         obs = to_ndarray(obs)
-        rew = to_ndarray([rew]).astype(np.float32)  # wrapped to be transferred to a array with shape (1,)
+        rew = to_ndarray(rew).astype(np.float32)  # wrapped to be transferred to an array with shape (1,)
         return BaseEnvTimestep(obs, rew, done, info)

     @property
2 changes: 1 addition & 1 deletion zoo/bsuite/envs/bsuite_lightzero_env.py
@@ -143,7 +143,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
         if obs.shape[0] == 1:
             obs = obs[0]
         obs = to_ndarray(obs)
-        rew = to_ndarray([rew])  # wrapped to be transfered to an array with shape (1,)
+        rew = to_ndarray(rew)  # wrapped to be transferred to an array with shape (1,)

         action_mask = np.ones(self.action_space.n, 'int8')
         obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
@@ -8,7 +8,7 @@
 num_simulations = 25
 update_per_collect = None
 replay_ratio = 0.25
-max_env_step = int(2e5)
+max_env_step = int(1e5)
 reanalyze_ratio = 0
 batch_size = 256
 num_unroll_steps = 5
@@ -1,4 +1,4 @@
-from zoo.classic_control.mountain_car.config.mtcar_muzero_config import main_config, create_config
+from zoo.classic_control.mountain_car.config.mountain_car_muzero_config import main_config, create_config
 from lzero.entry import eval_muzero
 import numpy as np
4 changes: 2 additions & 2 deletions zoo/classic_control/mountain_car/envs/mtcar_lightzero_env.py
@@ -78,7 +78,7 @@ def reset(self) -> np.ndarray:

     def step(self, action: np.ndarray) -> BaseEnvTimestep:
         # Making sure that input action is of numpy ndarray
-        assert isinstance(action, np.ndarray), type(action)
+        # assert isinstance(action, np.ndarray), type(action)

         # Extract action as int, 0-dim array
         action = action.squeeze()
@@ -94,7 +94,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:

         # Making sure we conform to di-engine conventions
         obs = to_ndarray(obs)
-        rew = to_ndarray([rew]).astype(np.float32)
+        rew = to_ndarray(rew).astype(np.float32)
         action_mask = np.ones(self.action_space.n, 'int8')
         obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
@@ -139,7 +139,7 @@ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
         self._eval_episode_return += rew
         obs = to_ndarray(obs).astype(np.float32)
         # wrapped to be transferred to an array with shape (1,)
-        rew = to_ndarray([rew]).astype(np.float32)
+        rew = to_ndarray(rew).astype(np.float32)

         if done:
             info['eval_episode_return'] = self._eval_episode_return
2 changes: 1 addition & 1 deletion zoo/dmc2gym/envs/dmc2gym_lightzero_env.py
@@ -265,7 +265,7 @@ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
         obs, rew, done, info = self._env.step(action)
         self._eval_episode_return += rew
         obs = to_ndarray(obs).astype(np.float32)
-        rew = to_ndarray([rew]).astype(np.float32)  # wrapped to be transferred to an array with shape (1,)
+        rew = to_ndarray(rew).astype(np.float32)  # wrapped to be transferred to an array with shape (1,)

         if done:
             info['eval_episode_return'] = self._eval_episode_return
266 changes: 266 additions & 0 deletions zoo/eval_benchmark.py
@@ -0,0 +1,266 @@
import os
import re
import subprocess
import multiprocessing
from multiprocessing import Pool

# Define the root path of the zoo directory.
ZOO_PATH = './'

# ===== NOTE: for environments with specific configurations, you may need to add custom cases in process_algorithm() =====
# Define the threshold list for the eval_episode_return_mean values.
THRESHOLD_LIST = {
    'cartpole_muzero': 200.0,  # Example threshold for cartpole_muzero
    'cartpole_unizero': 200.0,  # Example threshold for cartpole_unizero
    'atari_muzero': 18.0,  # Example threshold for atari_muzero (env is Pong by default)
    'atari_unizero': 18.0,  # Example threshold for atari_unizero (env is Pong by default)
    'dmc2gym_state_sampled_muzero': 700.0,  # Example threshold for dmc2gym_state_sampled_muzero (env is cartpole-swingup by default)
    'dmc2gym_state_sampled_unizero': 700.0,  # Example threshold for dmc2gym_state_sampled_unizero (env is cartpole-swingup by default)

    # Add more algorithms and their thresholds as needed
}

# Define the environment and algorithm list for testing.
ENV_ALGO_LIST = [
    {'env': 'cartpole', 'algo': 'muzero'},
    {'env': 'cartpole', 'algo': 'unizero'},
    {'env': 'atari', 'algo': 'muzero'},
    {'env': 'atari', 'algo': 'unizero'},
    {'env': 'dmc2gym_state', 'algo': 'sampled_muzero'},
    {'env': 'dmc2gym_state', 'algo': 'sampled_unizero'},
    # Add more environment and algorithm pairs as needed
]

# Define the evaluator log file name to look for.
EVALUATOR_LOG_FILE = 'evaluator_logger.txt'

# Define the summary log file to store results.
SUMMARY_LOG_FILE = 'benchmark_summary.txt'


def find_config(env: str, algo: str) -> str:
    """
    Recursively search the zoo directory for the config file of the given environment and algorithm.
    Args:
        env (str): The environment name (e.g., 'cartpole').
        algo (str): The algorithm name (e.g., 'muzero').
    Returns:
        str: The path to the config file if found, otherwise None.
    """
    for root, dirs, files in os.walk(ZOO_PATH):
        if env in root and 'config' in dirs:
            config_dir = os.path.join(root, 'config')
            for file in os.listdir(config_dir):
                if env + '_' + algo + '_config' in file and file.endswith('.py'):
                    print(f'[INFO] Found config file: {file}')
                    return os.path.join(config_dir, file)
    return None
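
# Example (hypothetical layout): assuming a file such as
# './classic_control/cartpole/config/cartpole_muzero_config.py' exists under ZOO_PATH,
# find_config('cartpole', 'muzero') would return that path.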

def run_algorithm_with_config(config_file: str) -> None:
    """
    Run the algorithm using the specified config file.
    Args:
        config_file (str): The path to the config file.
    Returns:
        None
    """
    # Obtain the directory and file name of the config file
    config_dir = os.path.dirname(config_file)
    config_filename = os.path.basename(config_file)

    # Save the current working directory
    original_dir = os.getcwd()

    try:
        # Change to the directory of the config file
        os.chdir(config_dir)
        # Build the command to run the algorithm
        command = f"python {config_filename}"
        # Run the command and capture any errors
        subprocess.run(command, shell=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"[ERROR] Error occurred while running the algorithm: {e}")
    finally:
        # Change back to the original working directory
        os.chdir(original_dir)
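
# Design note: shell=True with an f-string is concise but assumes a trusted, space-free config
# path; an argument list, e.g. subprocess.run(['python', config_filename], check=True), would
# avoid shell parsing entirely.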

def find_evaluator_log_path(algo: str, env: str) -> str:
    """
    Recursively search for the path of the 'evaluator_logger.txt' file generated during the algorithm's run,
    and select the most recent directory.
    NOTE: If the directory name is in the format '_seed<x>_<y>', extract <y> and choose the largest value;
    if it is in the format '_seed<x>', extract <x>.
    Args:
        algo (str): The algorithm name (e.g., 'muzero').
        env (str): The environment name (e.g., 'cartpole').
    Returns:
        str: The path to the 'evaluator_logger.txt' file, or None if not found.
    """
    latest_number = -1
    selected_log_path = None

    # Regular expression to match the '_seed<x>' or '_seed<x>_<y>' format
    seed_pattern = re.compile(r'_seed(\d+)(?:_(\d+))?')
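    # Illustration (hypothetical directory names): '.../cartpole_muzero_seed0' gives x=0 and no <y>,
    # while '.../cartpole_muzero_seed0_1699500000' gives x=0 and y=1699500000, so a timestamp-like
    # suffix <y> takes precedence when selecting the most recent run.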

    for root, dirs, files in os.walk(ZOO_PATH):
        # Check if the directory path contains the algorithm name and environment name
        if f'data_{algo}' in root and env in root:
            # Look for the 'evaluator_logger.txt' file in the directory
            if EVALUATOR_LOG_FILE in files:
                # Find the '_seed<x>' or '_seed<x>_<y>' part in the directory and extract numbers
                seed_match = seed_pattern.search(root)
                if seed_match:
                    x_value = int(seed_match.group(1))  # Extract <x>
                    y_value = seed_match.group(2)  # Extract <y>, may be None
                    if y_value:
                        number = int(y_value)  # If <y> exists, use <y> for comparison
                    else:
                        number = x_value  # If no <y>, use <x> for comparison

                    # Update to the latest number and record the corresponding log file path
                    if number > latest_number:
                        latest_number = number
                        selected_log_path = os.path.join(root, EVALUATOR_LOG_FILE)

    if selected_log_path:
        print(f'[INFO] Found latest evaluator log file: {selected_log_path}')
        return selected_log_path
    else:
        print('[INFO] No evaluator log file found.')
        return None


def parse_eval_return_mean(log_file_path: str) -> float:
    """
    Parse the eval_episode_return_mean from the evaluator log file.
    Args:
        log_file_path (str): The path to the evaluator log file.
    Returns:
        float: The eval_episode_return_mean as a float, or None if not found.
    """
    with open(log_file_path, 'r') as file:
        lines = file.readlines()

    for i, line in enumerate(lines):
        if 'eval_episode_return_mean' in line:
            if i + 2 < len(lines):
                next_line = lines[i + 2]
                parts = next_line.split('|')
                if len(parts) >= 4:
                    try:
                        return float(parts[3].strip())
                    except ValueError:
                        print(f"[ERROR] Failed to convert {parts[3].strip()} to float.")
                        return None
    return None
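
# Note on the expected log format: parse_eval_return_mean assumes a '|'-separated table in which
# the numeric value sits two lines below the row containing 'eval_episode_return_mean', in the
# fourth '|'-separated field; this layout is an assumption about evaluator_logger.txt rather
# than a documented format.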


def process_algorithm(item: dict) -> tuple:
    """
    Process a single environment-algorithm pair: find the config, run the algorithm, parse the log,
    and compare the result to the threshold.
    Args:
        item (dict): A dictionary containing 'env' and 'algo'.
    Returns:
        tuple: A tuple with the environment, algorithm, eval return mean, threshold, and result.
    """
    env = item['env']
    algo = item['algo']
    print(f"[INFO] Testing {algo} in {env}...")

    # Step 1: Find the config file
    # NOTE: for environments with specific configurations, add custom cases here
    if env == 'dmc2gym_state' and algo == 'sampled_muzero':
        config_file = './dmc2gym/config/dmc2gym_state_sampled_muzero_config.py'
    elif env == 'dmc2gym_state' and algo == 'sampled_unizero':
        config_file = './dmc2gym/config/dmc2gym_state_sampled_unizero_config.py'
    else:
        config_file = find_config(env, algo)

    if config_file is None:
        print(f"[WARNING] Config file for {algo} in {env} not found. Skipping...")
        return (env, algo, 'N/A', 'N/A', 'Skipped')

    # Step 2: Run the algorithm with the found config file
    run_algorithm_with_config(config_file)

    # Step 3: Find the evaluator log file
    # NOTE: for environments with specific configurations, add custom cases here
    if env == 'dmc2gym_state' and algo == 'sampled_muzero':
        log_file_path = find_evaluator_log_path('sampled_muzero', 'cartpole-swingup')
    elif env == 'dmc2gym_state' and algo == 'sampled_unizero':
        log_file_path = find_evaluator_log_path('sampled_unizero', 'cartpole-swingup')
    else:
        log_file_path = find_evaluator_log_path(algo, env)

    if log_file_path is None:
        print(f"[WARNING] Evaluator log file for {algo} in {env} not found. Skipping...")
        return (env, algo, 'N/A', 'N/A', 'Skipped')

    # Step 4: Parse the evaluator log file to get eval_episode_return_mean
    eval_return_mean = parse_eval_return_mean(log_file_path)
    if eval_return_mean is None:
        print(f"[ERROR] Failed to parse eval_episode_return_mean for {algo} in {env}.")
        return (env, algo, 'N/A', 'N/A', 'Failed to parse')

    # Step 5: Compare the eval_episode_return_mean with the threshold
    threshold = THRESHOLD_LIST.get(env + '_' + algo, float('inf'))
    result = 'Passed' if eval_return_mean > threshold else 'Failed'

    print(f"[INFO] {result} for {algo} in {env}. Eval mean return: {eval_return_mean}, Threshold: {threshold}")
    return (env, algo, eval_return_mean, threshold, result)
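
# Example return values (hypothetical numbers): for {'env': 'cartpole', 'algo': 'muzero'} a
# finished run might yield ('cartpole', 'muzero', 215.3, 200.0, 'Passed'), while a missing
# config file yields ('cartpole', 'muzero', 'N/A', 'N/A', 'Skipped').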


def eval_benchmark() -> None:
    """
    Run the benchmark test in parallel using multiprocessing, log each result, and output a summary table.
    Returns:
        None
    """
    # Use multiprocessing to process each environment-algorithm pair in parallel
    with Pool(multiprocessing.cpu_count()) as pool:
        results = pool.map(process_algorithm, ENV_ALGO_LIST)

    # Split the results into passed and failed counts
    passed_count = sum(1 for result in results if result[4] == 'Passed')
    failed_count = sum(1 for result in results if result[4] == 'Failed')

    # Print the summary table
    print("\n[RESULTS] Benchmark Summary Table")
    print(f"{'Environment':<20}{'Algorithm':<20}{'Eval Return Mean':<20}{'Threshold':<20}{'Result'}")
    for row in results:
        print(f"{row[0]:<20}{row[1]:<20}{row[2]:<20}{row[3]:<20}{row[4]}")

    print(f"\n[SUMMARY] Total Passed: {passed_count}, Total Failed: {failed_count}")

    # Save the results to a log file
    with open(SUMMARY_LOG_FILE, 'w') as summary_file:
        summary_file.write("[RESULTS] Benchmark Summary Table\n")
        summary_file.write(f"{'Environment':<20}{'Algorithm':<20}{'Eval Return Mean':<20}{'Threshold':<20}{'Result'}\n")
        for row in results:
            summary_file.write(f"{row[0]:<20}{row[1]:<20}{row[2]:<20}{row[3]:<20}{row[4]}\n")
        summary_file.write(f"\n[SUMMARY] Total Passed: {passed_count}, Total Failed: {failed_count}\n")
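
# Usage sketch (assuming the script is launched from the zoo/ directory, since ZOO_PATH is './'):
#     python eval_benchmark.py
# Per-pair progress is printed to stdout and the summary table is written to benchmark_summary.txt.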


if __name__ == "__main__":
    """
    This script automates the benchmarking of LightZero algorithms across different environments by:
    - Searching for algorithm configuration files,
    - Running the algorithms,
    - Parsing log files for key performance metrics, and
    - Comparing results to predefined thresholds.
    It supports parallel execution and generates a detailed log of the benchmarking results,
    making it a useful tool for testing and evaluating different algorithms in a standardized manner.
    """
    eval_benchmark()
2 changes: 1 addition & 1 deletion zoo/minigrid/envs/minigrid_lightzero_env.py
@@ -194,7 +194,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
                 print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!')
                 self._save_replay_count += 1
         obs = to_ndarray(obs)
-        rew = to_ndarray([rew])  # wrapped to be transferred to an array with shape (1,)
+        rew = to_ndarray(rew)  # wrapped to be transferred to an array with shape (1,)

         action_mask = np.ones(self.action_space.n, 'int8')
         obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
2 changes: 1 addition & 1 deletion zoo/mujoco/envs/mujoco_disc_lightzero_env.py
@@ -129,7 +129,7 @@ def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
             info['eval_episode_return'] = self._eval_episode_return

         obs = to_ndarray(obs).astype(np.float32)
-        rew = to_ndarray([rew]).astype(np.float32)
+        rew = to_ndarray(rew).astype(np.float32)

         action_mask = np.ones(self._action_space.n, 'int8')
         obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
2 changes: 1 addition & 1 deletion zoo/mujoco/envs/mujoco_lightzero_env.py
@@ -127,7 +127,7 @@ def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
             info['eval_episode_return'] = self._eval_episode_return

         obs = to_ndarray(obs).astype(np.float32)
-        rew = to_ndarray([rew]).astype(np.float32)
+        rew = to_ndarray(rew).astype(np.float32)

         action_mask = None
         obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
