Skip to content

Commit

Permalink
polish(pu): polish _forward_learn() and some data process operations (#…
Browse files Browse the repository at this point in the history
…191)

* polish(pu): polish mini_infer_size, obs CHW, prepare_obs, _forward_learn, to_play deepcopy, float

* fix(pu): fix test_muzero_game_buffer.py

* polish(pu): polish entropy caculation in policy

* polish(pu): polish to_play_batch deepcopy and .float() in tree_search

* sync code

* fix(pu): fix env_type AttributeError in tests

* fix(pu): fix test_muzero_game_buffer.py

* polish(pu): polish comments
  • Loading branch information
puyuan1996 authored Mar 12, 2024
1 parent b4066ac commit dbff144
Show file tree
Hide file tree
Showing 27 changed files with 487 additions and 457 deletions.
2 changes: 1 addition & 1 deletion lzero/mcts/buffer/game_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def default_config(cls: type) -> EasyDict:
# (bool) Whether to use the root value in the reanalyzing part. Please refer to EfficientZero paper for details.
use_root_value=False,
# (int) The number of samples required for mini inference.
mini_infer_size=256,
mini_infer_size=10240,
)

def __init__(self, cfg: dict):
Expand Down
8 changes: 3 additions & 5 deletions lzero/mcts/buffer/game_buffer_muzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
beg_index = self._cfg.mini_infer_size * i
end_index = self._cfg.mini_infer_size * (i + 1)

m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()
m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device)

# calculate the target value
m_output = model.initial_inference(m_obs)
Expand All @@ -397,7 +397,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A

# concat the output slices after model inference
if self._cfg.use_root_value:
# use the root values from MCTS, as in EfficiientZero
# use the root values from MCTS, as in EfficientZero
# the root values have limited improvement but require much more GPU actors;
_, reward_pool, policy_logits_pool, latent_state_roots = concat_output(
network_output, data_type='muzero'
Expand Down Expand Up @@ -472,8 +472,6 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
else:
target_values.append(0)
target_rewards.append(0.0)
# TODO: check
# target_rewards.append(reward)
value_index += 1

batch_rewards.append(target_rewards)
Expand Down Expand Up @@ -527,7 +525,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
for i in range(slices):
beg_index = self._cfg.mini_infer_size * i
end_index = self._cfg.mini_infer_size * (i + 1)
m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device).float()
m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device)
m_output = model.initial_inference(m_obs)
if not model.training:
# if not in training, obtain the scalars of the value/reward
Expand Down
4 changes: 1 addition & 3 deletions lzero/mcts/buffer/game_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea
self.zero_obs_shape = config.model.observation_shape
elif len(config.model.observation_shape) == 3:
# image obs input, e.g. atari environments
self.zero_obs_shape = (
config.model.observation_shape[-2], config.model.observation_shape[-1], config.model.image_channel
)
self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])

self.obs_segment = []
self.action_segment = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
frame_skip=4,
episode_life=True,
clip_rewards=True,
channel_last=True,
channel_last=False,
render_mode_human=False,
scale=True,
warp_frame=True,
Expand Down
1 change: 1 addition & 0 deletions lzero/mcts/tests/test_mcts_ctree.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def recurrent_inference(self, latent_states, reward_hidden_states, actions=None)
support_scale=300,
categorical_distribution=True,
),
env_type='not_board_games',
)

batch_size = env_nums = policy_config.batch_size
Expand Down
1 change: 1 addition & 0 deletions lzero/mcts/tests/test_mcts_ptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def recurrent_inference(self, hidden_states, reward_hidden_states, actions):
categorical_distribution=True,
support_scale=300,
),
env_type='not_board_games',
)
)

Expand Down
1 change: 1 addition & 0 deletions lzero/mcts/tests/test_mcts_sampled_ctree.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def test_mcts():
action_space_size=2,
categorical_distribution=True,
),
env_type='not_board_games',
)
)

Expand Down
6 changes: 3 additions & 3 deletions lzero/mcts/tests/test_muzero_game_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def test_sample_orig_data():
expected_keys = [
'collect_mcts_temperature', 'collect_epsilon', 'cur_lr', 'weighted_total_loss',
'total_loss', 'policy_loss', 'policy_entropy', 'reward_loss', 'value_loss',
'consistency_loss', 'value_priority_orig', 'value_priority', 'target_reward',
'target_value', 'transformed_target_reward', 'transformed_target_value',
'predicted_rewards', 'predicted_values', 'total_grad_norm_before_clip'
'consistency_loss', 'target_reward', 'target_value', 'transformed_target_reward',
'transformed_target_value', 'predicted_rewards', 'predicted_values',
'total_grad_norm_before_clip', 'value_priority_orig', 'value_priority',
]

# Assert that all keys are present in log_vars
Expand Down
Loading

0 comments on commit dbff144

Please sign in to comment.