Commit

puyuan1996 committed Sep 16, 2024
2 parents cbcf42c + c76cd51 commit dd89d1c
Showing 25 changed files with 1,315 additions and 78 deletions.
2 changes: 1 addition & 1 deletion lzero/entry/train_alphazero.py
@@ -106,7 +106,7 @@ def train_alphazero(
         )
 
         # Evaluate policy performance
-        if evaluator.should_eval(learner.train_iter):
+        if evaluator.should_eval(learner.train_iter) and learner.train_iter > 0:
             stop, reward = evaluator.eval(
                 learner.save_checkpoint,
                 learner.train_iter,
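
The added learner.train_iter > 0 guard skips the evaluation that would otherwise fire at train_iter == 0, before the policy has received any gradient update. A minimal sketch of the resulting gate, assuming should_eval triggers on the eval_freq=int(1e3) period set in the configs below (the real DI-engine should_eval may apply further conditions):

# Minimal sketch of the evaluation gate after this change. The modulo-based
# trigger is an assumed stand-in for evaluator.should_eval(); it is not the
# real DI-engine implementation.
def should_run_eval(train_iter: int, eval_freq: int = 1000) -> bool:
    periodic_trigger = train_iter % eval_freq == 0
    # train_iter starts at 0, so without the extra guard the untrained
    # policy would be evaluated once before the first gradient step.
    return periodic_trigger and train_iter > 0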
9 changes: 9 additions & 0 deletions lzero/policy/alphazero.py
@@ -364,6 +364,15 @@ def _get_simulation_env(self):
             else:
                 raise NotImplementedError
             self.simulate_env = Connect4Env(connect4_alphazero_config.env)
+        elif self._cfg.simulation_env_id == 'chess':
+            from zoo.board_games.chess.envs.chess_lightzero_env import ChessLightZeroEnv
+            if self._cfg.simulation_env_config_type == 'play_with_bot':
+                from zoo.board_games.chess.config.chess_alphazero_bot_mode_config import chess_alphazero_config
+            elif self._cfg.simulation_env_config_type == 'self_play':
+                from zoo.board_games.chess.config.chess_alphazero_sp_mode_config import chess_alphazero_config
+            else:
+                raise NotImplementedError
+            self.simulate_env = ChessLightZeroEnv(chess_alphazero_config.env)
         else:
             raise NotImplementedError

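For context, the new branch is selected entirely by two policy-config keys. A minimal sketch of the settings that route _get_simulation_env() to ChessLightZeroEnv, with values copied verbatim from the configs added below:

# Policy-config keys that reach the new chess branch (values taken from the
# configs added in this commit; all other policy fields are omitted here).
policy = dict(
    simulation_env_id='chess',                   # selects the chess elif branch
    simulation_env_config_type='play_with_bot',  # or 'self_play' for the sp-mode config
)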
114 changes: 114 additions & 0 deletions zoo/board_games/chess/config/chess_alphazero_bot_mode_config.py
@@ -0,0 +1,114 @@
from easydict import EasyDict

# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
# collector_env_num = 8
# n_episode = 8
# evaluator_env_num = 5
# num_simulations = 400
# update_per_collect = 200
# batch_size = 512
# max_env_step = int(1e6)
# mcts_ctree = False

# TODO: for debug
collector_env_num = 2
n_episode = 2
evaluator_env_num = 2
num_simulations = 4
update_per_collect = 2
batch_size = 2
max_env_step = int(1e4)
mcts_ctree = False
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================

chess_alphazero_config = dict(
    exp_name=f'data_az_ptree/chess_alphazero_bot-mode_ns{num_simulations}_upc{update_per_collect}_seed0',
    env=dict(
        board_size=8,
        battle_mode='play_with_bot_mode',
        channel_last=False,
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
        n_evaluator_episode=evaluator_env_num,
        manager=dict(shared_memory=False, ),
        # ==============================================================
        # for the creation of simulation env
        agent_vs_human=False,
        prob_random_agent=0,
        prob_expert_agent=0,
        scale=True,
        alphazero_mcts_ctree=mcts_ctree,
        save_replay_gif=False,
        replay_path_gif='./replay_gif',
        # ==============================================================
    ),
    policy=dict(
        mcts_ctree=mcts_ctree,
        # ==============================================================
        # for the creation of simulation env
        simulation_env_id='chess',
        simulation_env_config_type='play_with_bot',
        # ==============================================================
        model=dict(
            observation_shape=(8, 8, 20),
            action_space_size=int(8 * 8 * 73),
            # TODO: for debug
            num_res_blocks=1,
            num_channels=1,
            fc_value_layers=[16],
            fc_policy_layers=[16],
            # num_res_blocks=8,
            # num_channels=256,
            # fc_value_layers=[256, 256],
            # fc_policy_layers=[256, 256],
        ),
        cuda=True,
        board_size=8,
        update_per_collect=update_per_collect,
        batch_size=batch_size,
        optim_type='AdamW',
        lr_piecewise_constant_decay=False,
        learning_rate=0.0001,
        grad_clip_value=0.5,
        value_weight=1.0,
        entropy_weight=0.01,
        n_episode=n_episode,
        eval_freq=int(1e3),
        mcts=dict(num_simulations=num_simulations),
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
    ),
)

chess_alphazero_config = EasyDict(chess_alphazero_config)
main_config = chess_alphazero_config

chess_alphazero_create_config = dict(
    env=dict(
        type='chess_lightzero',
        import_names=['zoo.board_games.chess.envs.chess_lightzero_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(
        type='alphazero',
        import_names=['lzero.policy.alphazero'],
    ),
    collector=dict(
        type='episode_alphazero',
        import_names=['lzero.worker.alphazero_collector'],
    ),
    evaluator=dict(
        type='alphazero',
        import_names=['lzero.worker.alphazero_evaluator'],
    )
)
chess_alphazero_create_config = EasyDict(chess_alphazero_create_config)
create_config = chess_alphazero_create_config

if __name__ == '__main__':
    from lzero.entry import train_alphazero
    train_alphazero([main_config, create_config], seed=0, max_env_step=max_env_step)
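
For reference, the action_space_size=int(8 * 8 * 73) above matches the chess move encoding from the AlphaZero paper: each of the 64 origin squares gets 73 move planes. A quick sanity check of the arithmetic (the plane breakdown comes from the paper; this repo's ChessLightZeroEnv is assumed, not verified here, to use the same encoding):

# Breakdown of the 73 move planes per origin square, per the AlphaZero paper.
queen_like_moves = 7 * 8   # slide 1-7 squares in each of 8 directions
knight_moves = 8           # the 8 knight jumps
underpromotions = 3 * 3    # promote to knight/bishop/rook x 3 move directions
planes = queen_like_moves + knight_moves + underpromotions
assert planes == 73
assert 8 * 8 * planes == 4672  # == int(8 * 8 * 73) in the config above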

This file was deleted.

113 changes: 113 additions & 0 deletions zoo/board_games/chess/config/chess_alphazero_sp_mode_config.py
@@ -0,0 +1,113 @@
from easydict import EasyDict

# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
collector_env_num = 8
n_episode = 8
evaluator_env_num = 5
num_simulations = 400
update_per_collect = 200
batch_size = 512
max_env_step = int(1e6)
mcts_ctree = True

# TODO: for debug
# collector_env_num = 2
# n_episode = 2
# evaluator_env_num = 2
# num_simulations = 4
# update_per_collect = 2
# batch_size = 2
# max_env_step = int(1e4)
# mcts_ctree = False
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
chess_alphazero_config = dict(
    exp_name='data_az_ctree/chess_sp-mode_alphazero_seed0',
    env=dict(
        board_size=8,
        battle_mode='self_play_mode',
        channel_last=False,
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
        n_evaluator_episode=evaluator_env_num,
        manager=dict(shared_memory=False, ),
        # ==============================================================
        # for the creation of simulation env
        agent_vs_human=False,
        prob_random_agent=0,
        prob_expert_agent=0,
        scale=True,
        alphazero_mcts_ctree=mcts_ctree,
        save_replay_gif=False,
        replay_path_gif='./replay_gif',
        # ==============================================================
    ),
    policy=dict(
        mcts_ctree=mcts_ctree,
        # ==============================================================
        # for the creation of simulation env
        simulation_env_id='chess',
        simulation_env_config_type='self_play',
        # ==============================================================
        model=dict(
            observation_shape=(8, 8, 20),
            action_space_size=int(8 * 8 * 73),
            # TODO: for debug
            num_res_blocks=1,
            num_channels=1,
            fc_value_layers=[16],
            fc_policy_layers=[16],
            # num_res_blocks=8,
            # num_channels=256,
            # fc_value_layers=[256, 256],
            # fc_policy_layers=[256, 256],
        ),
        cuda=True,
        board_size=8,
        update_per_collect=update_per_collect,
        batch_size=batch_size,
        optim_type='AdamW',
        lr_piecewise_constant_decay=False,
        learning_rate=0.0001,
        grad_clip_value=0.5,
        value_weight=1.0,
        entropy_weight=0.01,
        n_episode=n_episode,
        eval_freq=int(1e3),
        mcts=dict(num_simulations=num_simulations),
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
    ),
)

chess_alphazero_config = EasyDict(chess_alphazero_config)
main_config = chess_alphazero_config

chess_alphazero_create_config = dict(
    env=dict(
        type='chess_lightzero',
        import_names=['zoo.board_games.chess.envs.chess_lightzero_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(
        type='alphazero',
        import_names=['lzero.policy.alphazero'],
    ),
    collector=dict(
        type='episode_alphazero',
        import_names=['lzero.worker.alphazero_collector'],
    ),
    evaluator=dict(
        type='alphazero',
        import_names=['lzero.worker.alphazero_evaluator'],
    )
)
chess_alphazero_create_config = EasyDict(chess_alphazero_create_config)
create_config = chess_alphazero_create_config

if __name__ == '__main__':
    from lzero.entry import train_alphazero
    train_alphazero([main_config, create_config], seed=0, max_env_step=max_env_step)
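
Because both new configs expose main_config and create_config at module level, training can also be launched from an external script instead of running the config file directly; a sketch, with the max_env_step override chosen purely for illustration:

# Launch self-play chess training from outside the config file.
from lzero.entry import train_alphazero
from zoo.board_games.chess.config.chess_alphazero_sp_mode_config import (
    main_config, create_config,
)

if __name__ == '__main__':
    train_alphazero([main_config, create_config], seed=0, max_env_step=int(5e5))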