feature(pu): add MemoryEnv, related utils and configs (#197)
* feature(pu): add MemoryEnvLightZero and related tests

* feature(pu): add render and save_replay options in MemoryEnvLightZero

* polish(pu): delete unused files, add requirements

* feature(pu): add memory train and eval configs

* polish(pu): polish memory configs

* polish(pu): add repository links for memory_lightzero_env

* polish(pu): add lightzero env class names
puyuan1996 authored Mar 19, 2024
1 parent c9fccf0 commit 9b0c0ae
Showing 29 changed files with 2,042 additions and 26 deletions.
4 changes: 2 additions & 2 deletions lzero/mcts/tests/test_game_segment.py
@@ -17,8 +17,8 @@ def test_game_segment(test_algo):
 from lzero.mcts.tree_search.mcts_ctree import EfficientZeroMCTSCtree as MCTSCtree
 from lzero.model.efficientzero_model import EfficientZeroModel as Model
 from lzero.mcts.tests.config.atari_efficientzero_config_for_test import atari_efficientzero_config as config
-from zoo.atari.envs.atari_lightzero_env import AtariLightZeroEnv
-envs = [AtariLightZeroEnv(config.env) for _ in range(config.env.evaluator_env_num)]
+from zoo.atari.envs.atari_lightzero_env import AtariEnvLightZero
+envs = [AtariEnvLightZero(config.env) for _ in range(config.env.evaluator_env_num)]

 elif test_algo == 'MuZero':
 from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree
3 changes: 2 additions & 1 deletion requirements.txt
@@ -5,4 +5,5 @@ numpy>=1.22.4
 pympler
 bsuite
 minigrid
-moviepy
+moviepy
+pycolab
31 changes: 16 additions & 15 deletions zoo/README.md

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions zoo/atari/envs/atari_lightzero_env.py
@@ -13,10 +13,10 @@


 @ENV_REGISTRY.register('atari_lightzero')
-class AtariLightZeroEnv(BaseEnv):
+class AtariEnvLightZero(BaseEnv):
 """
 Overview:
-AtariLightZeroEnv is a derived class from BaseEnv and represents the environment for the Atari LightZero game.
+AtariEnvLightZero is a derived class from BaseEnv and represents the environment for the Atari LightZero game.
 This class provides the necessary interfaces to interact with the environment, including reset, step, seed,
 close, etc. and manages the environment's properties such as observation_space, action_space, and reward_space.
 Properties:
@@ -79,7 +79,7 @@ def default_config(cls: type) -> EasyDict:
 Overview:
 Return the default configuration for the Atari LightZero environment.
 Arguments:
-- cls (:obj:`type`): The class AtariLightZeroEnv.
+- cls (:obj:`type`): The class AtariEnvLightZero.
 Returns:
 - cfg (:obj:`EasyDict`): The default configuration dictionary.
 """
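Because this commit only renames the class (AtariLightZeroEnv -> AtariEnvLightZero), call sites simply swap the identifier. A minimal usage sketch of the renamed class, assuming the BaseEnv interface described in the docstring above and the default_config classmethod from this diff; the BaseEnvTimestep field access is an assumption for illustration, not part of the commit:

# Hypothetical usage of the renamed env class; interface assumed from the docstring above.
from zoo.atari.envs.atari_lightzero_env import AtariEnvLightZero

cfg = AtariEnvLightZero.default_config()  # EasyDict with the default environment settings
env = AtariEnvLightZero(cfg)

obs = env.reset()
done = False
while not done:
    action = env.random_action()          # same helper used in the test below
    timestep = env.step(action)           # assumed to return a BaseEnvTimestep(obs, reward, done, info)
    done = timestep.done
env.close()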
6 changes: 3 additions & 3 deletions zoo/atari/tests/test_atari_lightzero_env.py
@@ -1,5 +1,5 @@
 import pytest
-from zoo.atari.envs.atari_lightzero_env import AtariLightZeroEnv
+from zoo.atari.envs.atari_lightzero_env import AtariEnvLightZero
 from easydict import EasyDict

 config = EasyDict(dict(
@@ -29,9 +29,9 @@
 config.max_episode_steps = config.eval_max_episode_steps

 @pytest.mark.envtest
-class TestAtariLightZeroEnv:
+class TestAtariEnvLightZero:
 def test_naive(self):
-env = AtariLightZeroEnv(config)
+env = AtariEnvLightZero(config)
 env.reset()
 while True:
 action = env.random_action()
2 changes: 1 addition & 1 deletion zoo/atari/tests/test_atari_lightzero_env_visualization.py
@@ -4,7 +4,7 @@
 from gym.wrappers import RecordVideo

 @pytest.mark.envtest
-class TestAtariLightZeroEnvVisualization:
+class TestAtariEnvLightZeroVisualization:

 def test_naive_env(self):
 import gym, random
Empty file added zoo/memory/__init__.py
Empty file.
Empty file added zoo/memory/config/__init__.py
Empty file.
102 changes: 102 additions & 0 deletions zoo/memory/config/memory_efficientzero_config.py
@@ -0,0 +1,102 @@
from easydict import EasyDict

env_id = 'key_to_door' # The name of the environment, options: 'visual_match', 'key_to_door'
memory_length = 30

max_env_step = int(1e6)

# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
seed = 0
collector_env_num = 8
n_episode = 8
evaluator_env_num = 3
num_simulations = 50
update_per_collect = 200
batch_size = 256
reanalyze_ratio = 0
td_steps = 5
policy_entropy_loss_weight = 0.
threshold_training_steps_for_final_temperature = int(5e5)
eps_greedy_exploration_in_collect = False
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================

memory_efficientzero_config = dict(
exp_name=f'data_ez_ctree/{env_id}_memlen-{memory_length}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}',
env=dict(
stop_value=int(1e6),
env_id=env_id,
flate_observation=True, # Whether to flatten the observation
max_frames={
"explore": 15,
"distractor": memory_length,
"reward": 15
}, # Maximum frames per phase
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
),
policy=dict(
model=dict(
observation_shape=25,
action_space_size=4,
model_type='mlp',
lstm_hidden_size=128,
latent_state_dim=128,
discrete_action_encoding_type='one_hot',
norm_type='BN',
),
policy_entropy_loss_weight=policy_entropy_loss_weight,
eps=dict(
eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect,
decay=int(2e5),
),
td_steps=td_steps,
manual_temperature_decay=True,
threshold_training_steps_for_final_temperature=threshold_training_steps_for_final_temperature,
cuda=True,
env_type='not_board_games',
game_segment_length=60,
update_per_collect=update_per_collect,
batch_size=batch_size,
optim_type='Adam',
lr_piecewise_constant_decay=False,
learning_rate=0.003,
num_simulations=num_simulations,
reanalyze_ratio=reanalyze_ratio,
n_episode=n_episode,
eval_freq=int(2e3),
replay_buffer_size=int(1e6),  # the capacity of the replay buffer, in terms of transitions.
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
),
)

memory_efficientzero_config = EasyDict(memory_efficientzero_config)
main_config = memory_efficientzero_config

memory_efficientzero_create_config = dict(
env=dict(
type='memory_lightzero',
import_names=['zoo.memory.envs.memory_lightzero_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='efficientzero',
import_names=['lzero.policy.efficientzero'],
),
collector=dict(
type='episode_muzero',
import_names=['lzero.worker.muzero_collector'],
)
)
memory_efficientzero_create_config = EasyDict(memory_efficientzero_create_config)
create_config = memory_efficientzero_create_config

if __name__ == "__main__":
from lzero.entry import train_muzero
train_muzero([main_config, create_config], seed=seed, max_env_step=max_env_step)
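The env class itself is not rendered in this excerpt (the zoo/memory/envs diffs are among the files not shown), so the following smoke test is only a sketch: the class name MemoryEnvLightZero comes from the commit message, the module path from create_config above, and the constructor and interface are assumed to mirror AtariEnvLightZero; the config fields are copied from the env block above.

# Hypothetical smoke test for the new memory env; names and interface are assumptions as noted above.
from easydict import EasyDict
from zoo.memory.envs.memory_lightzero_env import MemoryEnvLightZero

env_cfg = EasyDict(dict(
    env_id='key_to_door',                                   # or 'visual_match'
    flate_observation=True,                                  # field name as spelled in the config above
    max_frames=dict(explore=15, distractor=30, reward=15),   # phase lengths from the config above
))
env = MemoryEnvLightZero(env_cfg)          # assumed to accept the env config directly, like AtariEnvLightZero
obs = env.reset()
timestep = env.step(env.random_action())   # random_action() assumed, as in the Atari test
env.close()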
114 changes: 114 additions & 0 deletions zoo/memory/config/memory_muzero_config.py
@@ -0,0 +1,114 @@
from easydict import EasyDict

env_id = 'key_to_door' # The name of the environment, options: 'visual_match', 'key_to_door'
memory_length = 30

max_env_step = int(10e6)

# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
seed = 0
collector_env_num = 8
n_episode = 8
evaluator_env_num = 3
num_simulations = 50
update_per_collect = 200
batch_size = 256
reanalyze_ratio = 0
td_steps = 5

# debug
# collector_env_num = 1
# n_episode = 1
# evaluator_env_num = 1
# num_simulations = 5
# update_per_collect = 2
# batch_size = 2

policy_entropy_loss_weight = 1e-4
threshold_training_steps_for_final_temperature = int(5e5)
eps_greedy_exploration_in_collect = False
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================

memory_muzero_config = dict(
exp_name=f'data_mz_ctree/{env_id}_memlen-{memory_length}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_'
f'collect-eps-{eps_greedy_exploration_in_collect}_temp-final-steps-{threshold_training_steps_for_final_temperature}'
f'_pelw{policy_entropy_loss_weight}_seed{seed}',
env=dict(
stop_value=int(1e6),
env_id=env_id,
flate_observation=True, # Whether to flatten the observation
max_frames={
"explore": 15,
"distractor": memory_length,
"reward": 15
}, # Maximum frames per phase
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
),
policy=dict(
model=dict(
observation_shape=25,
action_space_size=4,
model_type='mlp',
latent_state_dim=128,
discrete_action_encoding_type='one_hot',
norm_type='BN',
self_supervised_learning_loss=True, # NOTE: default is False.
),
eps=dict(
eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect,
decay=int(2e5),
),
policy_entropy_loss_weight=policy_entropy_loss_weight,
td_steps=td_steps,
manual_temperature_decay=True,
threshold_training_steps_for_final_temperature=threshold_training_steps_for_final_temperature,
cuda=True,
env_type='not_board_games',
game_segment_length=60,
update_per_collect=update_per_collect,
batch_size=batch_size,
optim_type='Adam',
lr_piecewise_constant_decay=False,
learning_rate=0.003,
ssl_loss_weight=2, # NOTE: default is 0.
num_simulations=num_simulations,
reanalyze_ratio=reanalyze_ratio,
n_episode=n_episode,
eval_freq=int(2e3),
replay_buffer_size=int(1e6),  # the capacity of the replay buffer, in terms of transitions.
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
),
)

memory_muzero_config = EasyDict(memory_muzero_config)
main_config = memory_muzero_config

memory_muzero_create_config = dict(
env=dict(
type='memory_lightzero',
import_names=['zoo.memory.envs.memory_lightzero_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='muzero',
import_names=['lzero.policy.muzero'],
),
collector=dict(
type='episode_muzero',
import_names=['lzero.worker.muzero_collector'],
)
)
memory_muzero_create_config = EasyDict(memory_muzero_create_config)
create_config = memory_muzero_create_config

if __name__ == "__main__":
from lzero.entry import train_muzero
train_muzero([main_config, create_config], seed=seed, max_env_step=max_env_step)
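Both configs share the same env block, so switching between the two memory tasks is just a matter of changing env_id (and, if desired, memory_length). A small sketch of driving such a sweep from another script, assuming the module path in the diff header above is importable; the exp_name override is illustrative only:

# Hypothetical sweep over the two memory tasks; config module path taken from the diff header above.
from lzero.entry import train_muzero
from zoo.memory.config.memory_muzero_config import main_config, create_config

for env_id in ['key_to_door', 'visual_match']:
    main_config.env.env_id = env_id
    main_config.exp_name = f'data_mz_ctree/{env_id}_memlen-30_muzero_sweep_seed0'  # illustrative name
    train_muzero([main_config, create_config], seed=0, max_env_step=int(10e6))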