feature(zjow): add agent class to support LightZero's HuggingFace Model Zoo #163

Merged: 3 commits, Dec 5, 2023
Changes from 2 commits
1 change: 1 addition & 0 deletions lzero/agent/__init__.py
@@ -0,0 +1 @@
from .muzero import MuZeroAgent
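For orientation, here is a rough usage sketch of the new agent class exported above. The constructor arguments and the train/deploy method names are assumptions based on the PR description and the configs below; the actual interface is defined in lzero/agent/muzero.py, which is not shown in this excerpt.

from lzero.agent import MuZeroAgent

# Train a MuZero agent on CartPole-v0 with the bundled default config,
# then roll out the trained policy (argument and method names are assumed).
agent = MuZeroAgent(env_id='CartPole-v0', exp_name='CartPole-v0-MuZero', seed=0)
agent.train(step=int(1e5))
agent.deploy(enable_save_replay=True)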
Empty file added lzero/agent/config/__init__.py
8 changes: 8 additions & 0 deletions lzero/agent/config/muzero/__init__.py
@@ -0,0 +1,8 @@
from easydict import EasyDict
from . import gym_cartpole_v0

supported_env_cfg = {
    gym_cartpole_v0.cfg.main_config.env.env_id: gym_cartpole_v0.cfg,
}

supported_env_cfg = EasyDict(supported_env_cfg)
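A short sketch of how this registry could be consumed and extended; the lookup below uses the CartPole-v0 entry registered above, and the gym_pendulum_v1 module mentioned in the comments is hypothetical.

from lzero.agent.config.muzero import supported_env_cfg

# Look up the default MuZero agent config by env_id; attribute access works
# because the registry and each entry are EasyDict instances.
cfg = supported_env_cfg['CartPole-v0']
print(cfg.main_config.exp_name)  # -> 'CartPole-v0-MuZero'

# Registering another environment would follow the same pattern inside this
# __init__.py, e.g. with a hypothetical gym_pendulum_v1.py config module:
#   from . import gym_pendulum_v1
#   supported_env_cfg[gym_pendulum_v1.cfg.main_config.env.env_id] = gym_pendulum_v1.cfg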
76 changes: 76 additions & 0 deletions lzero/agent/config/muzero/gym_cartpole_v0.py
@@ -0,0 +1,76 @@
from easydict import EasyDict

# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
collector_env_num = 8
n_episode = 8
evaluator_env_num = 3
num_simulations = 25
update_per_collect = 100
batch_size = 256
max_env_step = int(1e5)
reanalyze_ratio = 0
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================

cfg = dict(
    main_config=dict(
        exp_name='CartPole-v0-MuZero',
        seed=0,
        env=dict(
            env_id='CartPole-v0',
            continuous=False,
            manually_discretization=False,
            collector_env_num=collector_env_num,
            evaluator_env_num=evaluator_env_num,
            n_evaluator_episode=evaluator_env_num,
            manager=dict(shared_memory=False, ),
        ),
        policy=dict(
            model=dict(
                observation_shape=4,
                action_space_size=2,
                model_type='mlp',
                lstm_hidden_size=128,
                latent_state_dim=128,
                self_supervised_learning_loss=True,  # NOTE: default is False.
                discrete_action_encoding_type='one_hot',
                norm_type='BN',
            ),
            cuda=True,
            env_type='not_board_games',
            game_segment_length=50,
            update_per_collect=update_per_collect,
            batch_size=batch_size,
            optim_type='Adam',
            lr_piecewise_constant_decay=False,
            learning_rate=0.003,
            ssl_loss_weight=2,  # NOTE: default is 0.
            num_simulations=num_simulations,
            reanalyze_ratio=reanalyze_ratio,
            n_episode=n_episode,
            eval_freq=int(2e2),
            replay_buffer_size=int(1e6),  # the size/capacity of the replay buffer, in terms of transitions.
            collector_env_num=collector_env_num,
            evaluator_env_num=evaluator_env_num,
        ),
        wandb_logger=dict(
            gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False
        ),
    ),
    create_config=dict(
Collaborator: Is this config specifically for the newly written MuZero agent class? Going forward, every config would effectively have to be rewritten here; couldn't we write a conversion function that converts LightZero's original configs into the format required here, so the redundant code can be avoided?

Contributor (Author): I would rather not write a conversion function. It is more direct if what is written in the code is exactly what the Agent loads. (If these configs are no longer needed in the future, deleting them is enough to deprecate them without touching any other components.) (With a conversion function, I am skeptical about the forward-compatibility guarantees, and it is also harder to read.)

Collaborator: Sounds good.

(An illustrative sketch of such a conversion function follows the end of this file's diff.)

        env=dict(
            type='cartpole_lightzero',
            import_names=['zoo.classic_control.cartpole.envs.cartpole_lightzero_env'],
        ),
        env_manager=dict(type='subprocess'),
        policy=dict(
            type='muzero',
            import_names=['lzero.policy.muzero'],
        ),
    ),
)

cfg = EasyDict(cfg)
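The conversion function discussed in the review thread above was not adopted in this PR. Had it been, it might have looked roughly like the sketch below; the function name and the field mapping are assumptions inferred from the layout of this config file, not an actual LightZero utility.

from easydict import EasyDict


def convert_lz_config(main_config, create_config, wandb_logger=None):
    """Hypothetical helper that wraps an existing LightZero (main_config, create_config)
    pair into the single-dict layout used by the agent configs in this directory.
    Shown only to illustrate the reviewer's suggestion; not part of this PR."""
    main_config = EasyDict(main_config)
    # The agent configs additionally carry a wandb_logger section inside main_config.
    main_config.wandb_logger = EasyDict(wandb_logger or dict(
        gradient_logger=False, video_logger=False, plot_logger=False,
        action_logger=False, return_logger=False,
    ))
    return EasyDict(dict(main_config=main_config, create_config=EasyDict(create_config)))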