fix(sk): fix stochastic_muzero_model_mlp.py with chance encoder (#284)
* fix(sk): Fix stochastic_muzero MLP model issues with the chance encoder enabled

- Use chance_space_size in the dynamics_network.
- Update the _dynamics method to one-hot encode the action with chance_space_size (sketched below).
- Use afterstate_dynamics_network instead of dynamics_network in the _afterstate_dynamics method.
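For context, a minimal standalone sketch (not the repository's code) of what the fixed `_dynamics` now does: the "action" fed to the stochastic dynamics network is a chance outcome, so it is one-hot encoded with chance_space_size rather than action_space_size before being concatenated with the latent state. Sizes are illustrative.

import torch

batch_size, latent_state_dim, chance_space_size = 8, 128, 3  # illustrative sizes

latent_state = torch.randn(batch_size, latent_state_dim)
chance = torch.randint(0, chance_space_size, (batch_size,))  # sampled chance outcomes

# (batch_size, ) -> (batch_size, 1), as in the diff below
if len(chance.shape) == 1:
    chance = chance.unsqueeze(-1)

# One-hot encode with chance_space_size, not action_space_size
chance_one_hot = torch.zeros(chance.shape[0], chance_space_size)
chance_one_hot.scatter_(1, chance.long(), 1)

state_chance_encoding = torch.cat((latent_state, chance_one_hot), dim=1)
# Matches num_channels = latent_state_dim + chance_space_size passed to DynamicsNetwork
assert state_chance_encoding.shape == (batch_size, latent_state_dim + chance_space_size)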

Signed-off-by: Shivam Kumar <kumar.shivam.jarvis@gmail.com>

* feature(sk): Modify the CartPole env to add support for chance values

Added for testing the chance encoder with stochastic MuZero.
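The idea behind the env change, as a small hedged sketch (it mirrors the diff further down, not the env code itself): a chance outcome in {0, 1, 2} picks one of the first three observation entries and doubles it, and the index is exposed as the ground-truth chance label.

import random
import numpy as np

obs = np.array([0.01, -0.02, 0.03, 0.04], dtype=np.float32)  # a CartPole-like observation
chance_value = random.randint(0, 2)   # chance space of size 3
obs[chance_value] *= 2                # perturb one observation entry
obs_dict = {'observation': obs, 'chance': chance_value}  # the env also returns the label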

Signed-off-by: Shivam Kumar <kumar.shivam.jarvis@gmail.com>

---------

Signed-off-by: Shivam Kumar <kumar.shivam.jarvis@gmail.com>
Co-authored-by: 蒲源 <48008469+puyuan1996@users.noreply.github.com>
ShivamKumar2002 and puyuan1996 authored Oct 18, 2024
1 parent 0fe817e commit 761bfa9
Showing 3 changed files with 41 additions and 26 deletions.
38 changes: 15 additions & 23 deletions lzero/model/stochastic_muzero_model_mlp.py
@@ -113,8 +113,8 @@ def __init__(
         # here, the input is two concatenated frames
         self.chance_encoder = ChanceEncoder(observation_shape * 2, chance_space_size, encoder_backbone_type='mlp')
         self.dynamics_network = DynamicsNetwork(
-            action_encoding_dim=self.action_encoding_dim,
-            num_channels=self.latent_state_dim + self.action_encoding_dim,
+            action_encoding_dim=self.chance_space_size,
+            num_channels=self.latent_state_dim + self.chance_space_size,
             common_layer_num=2,
             fc_reward_layers=fc_reward_layers,
             output_support_size=self.reward_support_size,
@@ -190,27 +190,19 @@ def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[t
         """
         # NOTE: the discrete action encoding type is important for some environments

-        # discrete action space
-        if self.discrete_action_encoding_type == 'one_hot':
-            # Stack latent_state with the one hot encoded action
-            if len(action.shape) == 1:
-                # (batch_size, ) -> (batch_size, 1)
-                # e.g., torch.Size([8]) -> torch.Size([8, 1])
-                action = action.unsqueeze(-1)
+        # Stack latent_state with the one hot encoded action
+        if len(action.shape) == 1:
+            # (batch_size, ) -> (batch_size, 1)
+            # e.g., torch.Size([8]) -> torch.Size([8, 1])
+            action = action.unsqueeze(-1)

-            # transform action to one-hot encoding.
-            # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4)
-            action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device)
-            # transform action to torch.int64
-            action = action.long()
-            action_one_hot.scatter_(1, action, 1)
-            action_encoding = action_one_hot
-        elif self.discrete_action_encoding_type == 'not_one_hot':
-            action_encoding = action / self.action_space_size
-            if len(action_encoding.shape) == 1:
-                # (batch_size, ) -> (batch_size, 1)
-                # e.g., torch.Size([8]) -> torch.Size([8, 1])
-                action_encoding = action_encoding.unsqueeze(-1)
+        # transform action to one-hot encoding.
+        # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4)
+        action_one_hot = torch.zeros(action.shape[0], self.chance_space_size, device=action.device)
+        # transform action to torch.int64
+        action = action.long()
+        action_one_hot.scatter_(1, action, 1)
+        action_encoding = action_one_hot

         action_encoding = action_encoding.to(latent_state.device).float()
         # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or
@@ -274,7 +266,7 @@ def _afterstate_dynamics(self, latent_state: torch.Tensor, action: torch.Tensor)
         # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type.
         state_action_encoding = torch.cat((latent_state, action_encoding), dim=1)

-        next_latent_state, reward = self.dynamics_network(state_action_encoding)
+        next_latent_state, reward = self.afterstate_dynamics_network(state_action_encoding)

         if not self.state_norm:
             return next_latent_state, reward
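To see why this matters, here is a hedged conceptual sketch (stand-in linear layers, not the repository's networks) of the two-step stochastic MuZero transition this file wires up: the afterstate dynamics consumes the latent state plus an action encoding, while the chance dynamics consumes the afterstate plus a one-hot chance encoding of width chance_space_size, which is why _afterstate_dynamics must call afterstate_dynamics_network and why DynamicsNetwork is sized with chance_space_size.

import torch
import torch.nn as nn

latent_dim, action_space_size, chance_space_size = 128, 2, 3  # illustrative sizes

# Stand-in single-layer networks; the real model uses MLP DynamicsNetwork modules.
afterstate_dynamics_net = nn.Linear(latent_dim + action_space_size, latent_dim)
dynamics_net = nn.Linear(latent_dim + chance_space_size, latent_dim)

latent_state = torch.randn(1, latent_dim)
action_one_hot = nn.functional.one_hot(torch.tensor([1]), action_space_size).float()
chance_one_hot = nn.functional.one_hot(torch.tensor([2]), chance_space_size).float()

# afterstate_dynamics(latent_state, action) -> afterstate
afterstate = afterstate_dynamics_net(torch.cat([latent_state, action_one_hot], dim=1))
# dynamics(afterstate, chance) -> next latent state
next_latent_state = dynamics_net(torch.cat([afterstate, chance_one_hot], dim=1))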
7 changes: 5 additions & 2 deletions (CartPole stochastic MuZero config; file path not shown in this view)
@@ -11,6 +11,7 @@
 batch_size = 256
 max_env_step = int(1e5)
 reanalyze_ratio = 0
+enable_chance = False
 # ==============================================================
 # end of the most frequently changed config specified by the user
 # ==============================================================
@@ -20,6 +21,7 @@
     env=dict(
         env_id='CartPole-v0',
         continuous=False,
+        enable_chance=enable_chance,
         manually_discretization=False,
         collector_env_num=collector_env_num,
         evaluator_env_num=evaluator_env_num,
@@ -30,7 +32,7 @@
         model=dict(
             observation_shape=4,
             action_space_size=2,
-            chance_space_size=2,
+            chance_space_size=3,
             model_type='mlp',
             lstm_hidden_size=128,
             latent_state_dim=128,
@@ -51,6 +53,7 @@
         ssl_loss_weight=2,  # NOTE: default is 0.
         num_simulations=num_simulations,
         reanalyze_ratio=reanalyze_ratio,
+        use_ture_chance_label_in_chance_encoder=enable_chance,
         n_episode=n_episode,
         eval_freq=int(2e2),
         replay_buffer_size=int(1e6),
@@ -78,4 +81,4 @@

 if __name__ == "__main__":
     from lzero.entry import train_muzero
-    train_muzero([main_config, create_config], seed=0, model_path=main_config.policy.model_path, max_env_step=max_env_step)
+    train_muzero([main_config, create_config], seed=0, model_path=main_config.policy.model_path, max_env_step=max_env_step)
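A small hedged sanity check (not from the repository) relating the config above to the rest of the commit: the chance encoder is built on two concatenated frames, so with observation_shape=4 it sees 8 inputs, and chance_space_size=3 matches the env's random.randint(0, 2) chance label in the diff below.

import torch

observation_shape, chance_space_size = 4, 3  # values from the config above

obs_t = torch.randn(1, observation_shape)    # frame at time t
obs_tp1 = torch.randn(1, observation_shape)  # frame at time t + 1
encoder_input = torch.cat([obs_t, obs_tp1], dim=1)

# ChanceEncoder(observation_shape * 2, chance_space_size, ...) expects this input width
assert encoder_input.shape[1] == observation_shape * 2
# and predicts a distribution over chance_space_size = 3 outcomes,
# matching random.randint(0, 2) in the CartPole env change below.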
22 changes: 21 additions & 1 deletion zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py
@@ -1,5 +1,5 @@
 import copy
-import os
+import random
 from datetime import datetime
 from typing import Union, Optional, Dict

@@ -26,6 +26,11 @@ class CartPoleEnv(BaseEnv):
     config = dict(
         # env_id (str): The name of the CartPole environment.
         env_id="CartPole-v0",
+        # enable_chance (bool): Whether to enable chance in the observation.
+        # If enabled, one of the first 3 values of the observation will be multiplied by 2.
+        # Used for testing the chance encoder in stochastic_muzero.
+        # The chance space is 3.
+        enable_chance=False,
         # save_replay_gif (bool): If True, saves the replay as a gif.
         save_replay_gif=False,
         # replay_path_gif (str or None): The path to save the gif replay. If None, gif will not be saved.
@@ -46,6 +51,7 @@ def __init__(self, cfg: dict = {}) -> None:
             cfg (dict): Configuration dict that includes `env_id`, `save_replay_gif`, and `replay_path_gif`.
         """
         self._cfg = cfg
+        self._enable_chance = self._cfg.get('enable_chance', False)
         self._init_flag = False
         self._replay_path_gif = cfg.get('replay_path_gif', None)
         self._save_replay_gif = cfg.get('save_replay_gif', False)
@@ -85,6 +91,13 @@ def reset(self) -> Dict[str, np.ndarray]:
         action_mask = np.ones(self.action_space.n, 'int8')
         obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}

+        # this is to artificially introduce randomness in order to evaluate the performance of
+        # stochastic_muzero on state input
+        if self._enable_chance:
+            chance_value = random.randint(0, 2)
+            obs['observation'][chance_value] *= 2
+            obs['chance'] = chance_value
+
         return obs

     def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
@@ -126,6 +139,13 @@ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
         action_mask = np.ones(self.action_space.n, 'int8')
         obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}

+        # this is to artificially introduce randomness in order to evaluate the performance of
+        # stochastic_muzero on state input
+        if self._enable_chance:
+            chance_value = random.randint(0, 2)
+            obs['observation'][chance_value] *= 2
+            obs['chance'] = chance_value
+
         return BaseEnvTimestep(obs, rew, done, info)

     def save_gif_replay(self) -> None: