Skip to content

Commit

Permalink
fix(pu): fix render path bug in tictactoe and cartpole
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Dec 2, 2023
1 parent 545fb10 commit e763282
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 18 deletions.
5 changes: 3 additions & 2 deletions zoo/atari/envs/atari_lightzero_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ class AtariLightZeroEnv(BaseEnv):
render_mode_human=False,
# (bool) If True, a video of the game play is saved.
save_replay=False,
# (str) The path to save the video.
replay_path='./video',
# replay_path (str or None): The path to save the replay video. If None, the replay will not be saved.
# Only effective when env_manager.type is 'base'.
replay_path=None,
# (bool) If set to True, the game screen is converted to grayscale, reducing the complexity of the observation space.
gray_scale=True,
# (int) The number of frames to skip between each action. Higher values result in faster simulation.
Expand Down
9 changes: 3 additions & 6 deletions zoo/board_games/connect4/envs/connect4_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@
from zoo.board_games.mcts_bot import MCTSBot





@ENV_REGISTRY.register('connect4')
class Connect4Env(BaseEnv):
config = dict(
Expand All @@ -62,8 +59,8 @@ class Connect4Env(BaseEnv):
# (str) The render mode. Options are 'None', 'state_realtime_mode', 'image_realtime_mode' or 'image_savefile_mode'.
# If None, then the game will not be rendered.
render_mode=None,
# (str) The suffix of the replay name.
replay_path='./video',
# (str or None) The directory in which to save the replay file. If None, the file is saved in the current directory.
replay_path=None,
# (str) The type of the bot of the environment.
bot_action_type='rule',
# (bool) Whether to let a human play against the agent when evaluating. If False, the bot is used to evaluate the agent.
Expand Down Expand Up @@ -479,7 +476,7 @@ def save_render_output(self, replay_name_suffix: str = '', replay_path: str = No
else:
if not os.path.exists(replay_path):
os.makedirs(replay_path)
filename = replay_path+f'/connect4_{replay_name_suffix}.{format}'
filename = replay_path + f'/connect4_{replay_name_suffix}.{format}'

if format == 'gif':
# Save frames as a GIF with a duration of 0.1 seconds per frame.
Expand Down
4 changes: 2 additions & 2 deletions zoo/board_games/gomoku/envs/gomoku_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ class GomokuEnv(BaseEnv):
# (str) The render mode. Options are 'None', 'state_realtime_mode', 'image_realtime_mode' or 'image_savefile_mode'.
# If None, then the game will not be rendered.
render_mode=None,
# (str) The suffix of the replay name.
replay_path='./video',
# (str or None) The directory in which to save the replay file. If None, the file is saved in the current directory.
replay_path=None,
# (float) The scale of the render screen.
screen_scaling=9,
# (bool) Whether to use the 'channel last' format for the observation space. If False, 'channel first' format is used.
Expand Down
20 changes: 18 additions & 2 deletions zoo/board_games/tictactoe/envs/tictactoe_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,33 @@ def _get_done_winner_func_lru(board_tuple):

@ENV_REGISTRY.register('tictactoe')
class TicTacToeEnv(BaseEnv):

config = dict(
# env_name (str): The name of the environment.
env_name="TicTacToe",
# battle_mode (str): The mode of the battle. Choices are 'self_play_mode', 'play_with_bot_mode' or 'eval_mode'.
battle_mode='self_play_mode',
mcts_mode='self_play_mode', # only used in AlphaZero
bot_action_type='v0', # {'v0', 'alpha_beta_pruning'}
# mcts_mode (str): The mode of Monte Carlo Tree Search. This is only used in AlphaZero.
mcts_mode='self_play_mode',
# bot_action_type (str): The type of action the bot should take. Choices are 'v0' or 'alpha_beta_pruning'.
bot_action_type='v0',
# save_replay_gif (bool): If True, the replay will be saved as a gif file.
save_replay_gif=False,
# replay_path_gif (str): The path to save the replay gif.
replay_path_gif='./replay_gif',
# agent_vs_human (bool): If True, the agent will play against a human.
agent_vs_human=False,
# prob_random_agent (float): The probability, in [0, 1], with which the random agent is used.
prob_random_agent=0,
# prob_expert_agent (float): The probability, in [0, 1], with which the expert agent is used.
prob_expert_agent=0,
# channel_last (bool): If True, the channel will be the last dimension.
channel_last=True,
# scale (bool): If True, the pixel values will be scaled.
scale=True,
# stop_value (int): The value to stop the game.
stop_value=1,
# alphazero_mcts_ctree (bool): If True, the C++ (ctree) implementation of AlphaZero's Monte Carlo Tree Search is used instead of the Python one.
alphazero_mcts_ctree=False,
)

Expand Down
3 changes: 2 additions & 1 deletion zoo/box2d/lunarlander/envs/lunarlander_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class LunarLanderEnv(CartPoleEnv):
save_replay_gif=False,
# (str or None) The path to save the replay gif. If None, the replay gif will not be saved.
replay_path_gif=None,
# (str or None) The path to save the replay. If None, the replay will not be saved.
# replay_path (str or None): The path to save the replay video. If None, the replay will not be saved.
# Only effective when env_manager.type is 'base'.
replay_path=None,
# (bool) If True, the action will be scaled.
act_scale=True,
Expand Down
16 changes: 16 additions & 0 deletions zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from datetime import datetime
from typing import Union, Optional, Dict

Expand All @@ -7,6 +8,7 @@
from ding.envs import ObsPlusPrevActRewWrapper
from ding.torch_utils import to_ndarray
from ding.utils import ENV_REGISTRY
from easydict import EasyDict


@ENV_REGISTRY.register('cartpole_lightzero')
Expand All @@ -18,6 +20,20 @@ class CartPoleEnv(BaseEnv):
environment.
"""

config = dict(
# env_name (str): The name of the environment.
env_name="CartPole-v0",
# replay_path (str or None): The path to save the replay video. If None, the replay will not be saved.
# Only effective when env_manager.type is 'base'.
replay_path=None,
)

@classmethod
def default_config(cls: type) -> EasyDict:
    """
    Build the default configuration object for this environment class.

    Returns:
        EasyDict: A deep copy of ``cls.config`` with an extra ``cfg_type`` entry holding the class name
        followed by ``'Dict'``.
    """
    # Deep-copy first so that changes to the returned config never reach the class-level ``config``.
    config_copy = copy.deepcopy(cls.config)
    default_cfg = EasyDict(config_copy)
    type_tag = cls.__name__ + 'Dict'
    default_cfg.cfg_type = type_tag
    return default_cfg

def __init__(self, cfg: dict = {}) -> None:
"""
Initialize the environment with a configuration dictionary. Sets up spaces for observations, actions, and rewards.
Expand Down
3 changes: 2 additions & 1 deletion zoo/classic_control/pendulum/envs/pendulum_lightzero_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def default_config(cls: type) -> EasyDict:
config = dict(
# (bool) Whether to use continuous action space
continuous=True,
# (str) The path to save replay videos
# replay_path (str or None): The path to save the replay video. If None, the replay will not be saved.
# Only effective when env_manager.type is 'base'.
replay_path=None,
# (bool) Whether to scale action into [-2, 2]
act_scale=True,
Expand Down
8 changes: 4 additions & 4 deletions zoo/game_2048/envs/game_2048_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def __init__(self, cfg: dict) -> None:
self.replay_format = cfg.replay_format
self.replay_name_suffix = cfg.replay_name_suffix
self.replay_path = cfg.replay_path
self.render_mode = cfg.render_mode
self.render_mode = cfg.render_mode

self.channel_last = cfg.channel_last
self.obs_type = cfg.obs_type
Expand Down Expand Up @@ -356,8 +356,8 @@ def step(self, action):
if done:
info['eval_episode_return'] = self._final_eval_reward
if self.render_mode == 'image_savefile_mode':
self.save_render_output(replay_name_suffix=self.replay_name_suffix, replay_path=self.replay_path,
format=self.replay_format)
self.save_render_output(replay_name_suffix=self.replay_name_suffix, replay_path=self.replay_path,
format=self.replay_format)

return BaseEnvTimestep(observation, reward, done, info)

Expand Down Expand Up @@ -718,7 +718,7 @@ def save_render_output(self, replay_name_suffix: str = '', replay_path=None, for
else:
if not os.path.exists(replay_path):
os.makedirs(replay_path)
filename = replay_path+f'/2048_{replay_name_suffix}.{format}'
filename = replay_path + f'/2048_{replay_name_suffix}.{format}'

if format == 'gif':
imageio.mimsave(filename, self.frames, 'GIF')
Expand Down

0 comments on commit e763282

Please sign in to comment.