Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/ppo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def update_policy(self):
update_stats = self.policy.bc_module.update()
for stat, val in update_stats.items():
self.stats[stat].append(val)
self.training_buffer.reset_update_buffer()
self.clear_buffer()
self.trainer_metrics.end_policy_update()


Expand Down
7 changes: 7 additions & 0 deletions ml-agents/mlagents/trainers/rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,13 @@ def end_episode(self) -> None:
for agent_id in rewards:
rewards[agent_id] = 0

def clear_buffer(self) -> None:
    """
    Reset the update buffer held by this trainer's training buffer.

    Data accumulates in the buffer while running inference. When no
    training is taking place, call this method in place of update_policy
    so that the accumulated data is dropped instead of growing unbounded.
    """
    buffer = self.training_buffer
    buffer.reset_update_buffer()

def add_policy_outputs(
self, take_action_outputs: ActionInfoOutputs, agent_id: str, agent_idx: int
) -> None:
Expand Down
3 changes: 3 additions & 0 deletions ml-agents/mlagents/trainers/trainer_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,4 +296,7 @@ def advance(self, env: EnvManager) -> int:
with hierarchical_timer("update_policy"):
trainer.update_policy()
env.set_policy(brain_name, trainer.policy)
else:
# Avoid memory leak during inference
trainer.clear_buffer()
return len(new_step_infos)