polish(pu): rename model_update_ratio to replay_ratio
puyuan1996 committed Jul 25, 2024
1 parent 00f82fb commit fff7fde
Showing 2 changed files with 3 additions and 3 deletions.
lzero/entry/train_unizero.py (2 changes: 1 addition & 1 deletion)
@@ -147,7 +147,7 @@ def train_unizero(
         update_per_collect = cfg.policy.update_per_collect
         if update_per_collect is None:
             collected_transitions_num = sum(len(game_segment) for game_segment in new_data[0])
-            update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
+            update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
 
         # Update replay buffer
         replay_buffer.push_game_segments(new_data)
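
For context, the rename does not change the semantics: when update_per_collect is left as None, the number of gradient steps per collection phase is derived from the volume of freshly collected data. A minimal sketch of that calculation, with illustrative numbers (only replay_ratio and the formula come from the diff; the values are hypothetical):

    # Illustrative only: mirrors the replay-ratio logic shown in train_unizero.py.
    replay_ratio = 0.25                 # cfg.policy.replay_ratio default
    collected_transitions_num = 1000    # hypothetical sum of game-segment lengths
    update_per_collect = int(collected_transitions_num * replay_ratio)
    print(update_per_collect)           # 250 gradient steps before the next collection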
lzero/policy/unizero.py (4 changes: 2 additions & 2 deletions)
@@ -176,10 +176,10 @@ class UniZeroPolicy(MuZeroPolicy):
         # collect data -> update policy -> collect data -> ...
         # For different envs, we have different episode lengths,
         # so we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor.
-        # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically.
+        # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically.
         update_per_collect=None,
         # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is None.
-        model_update_ratio=0.25,
+        replay_ratio=0.25,
         # (int) Minibatch size for one gradient descent.
         batch_size=256,
         # (str) Optimizer for training policy network. ['SGD', 'Adam']
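
The comment above also gives a rule of thumb for choosing update_per_collect manually. A hedged worked example, with assumed values for collector_env_num, episode_length, and reuse_factor (none of these numbers come from the config; only batch_size=256 is a default shown in the diff):

    # Hypothetical values; only batch_size matches the config default above.
    collector_env_num = 8     # assumed number of parallel collector envs
    episode_length = 400      # assumed average episode length
    batch_size = 256          # default from the UniZeroPolicy config
    reuse_factor = 2          # assumed data-reuse multiplier
    update_per_collect = int(collector_env_num * episode_length / batch_size * reuse_factor)
    print(update_per_collect)  # 25 gradient steps per collection phase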
