Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/ppo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def update_policy(self):
update_stats = self.policy.bc_module.update()
for stat, val in update_stats.items():
self.stats[stat].append(val)
self.training_buffer.reset_update_buffer()
self.clear_update_buffer()
self.trainer_metrics.end_policy_update()


Expand Down
7 changes: 7 additions & 0 deletions ml-agents/mlagents/trainers/rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,13 @@ def end_episode(self) -> None:
for agent_id in rewards:
rewards[agent_id] = 0

def clear_update_buffer(self) -> None:
    """
    Empty out the buffers accumulated while running inference.

    When training is disabled, this should be invoked in place of
    update_policy so the update buffer does not grow without bound.
    """
    self.training_buffer.reset_update_buffer()

def add_policy_outputs(
self, take_action_outputs: ActionInfoOutputs, agent_id: str, agent_idx: int
) -> None:
Expand Down
10 changes: 10 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import mlagents.trainers.tests.mock_brain as mb
import numpy as np
from mlagents.trainers.rl_trainer import RLTrainer
from mlagents.trainers.tests.test_buffer import construct_fake_buffer


@pytest.fixture
Expand Down Expand Up @@ -92,3 +93,12 @@ def test_rl_trainer(add_policy_outputs, add_rewards_outputs, num_vis_obs):
for rewards in trainer.collected_rewards.values():
for agent_id in rewards:
assert rewards[agent_id] == 0


def test_clear_update_buffer():
    """clear_update_buffer should leave every array in the update buffer empty."""
    trainer = create_rl_trainer()
    trainer.training_buffer = construct_fake_buffer()
    # Populate the update buffer so the clear actually has something to remove.
    trainer.training_buffer.append_update_buffer(2, batch_size=None, training_length=2)
    trainer.clear_update_buffer()
    # Only the values matter here; iterate .values() instead of discarding keys.
    for arr in trainer.training_buffer.update_buffer.values():
        assert len(arr) == 0
28 changes: 28 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_trainer_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,31 @@ def test_take_step_adds_experiences_to_trainer_and_trains():
)
trainer_mock.update_policy.assert_called_once()
trainer_mock.increment_step.assert_called_once()


def test_take_step_if_not_training():
    """When train_model is off, advance() should still step the environment,
    feed experiences to the trainer, and clear the update buffer instead of
    updating the policy."""
    tc, trainer_mock = trainer_controller_with_take_step_mocks()
    tc.train_model = False
    trainer_mock.is_ready_update = MagicMock(return_value=False)

    brain_action_infos = {"testbrain": MagicMock()}
    prior_step = EnvironmentStep(Mock(), Mock(), brain_action_infos)
    latest_step = EnvironmentStep(Mock(), Mock(), brain_action_infos)

    mock_env = MagicMock()
    mock_env.reset.return_value = [prior_step]
    mock_env.step.return_value = [latest_step]

    tc.advance(mock_env)

    # Environment was stepped exactly once and never reset.
    mock_env.reset.assert_not_called()
    mock_env.step.assert_called_once()
    trainer_mock.add_experiences.assert_called_once_with(
        latest_step.previous_all_brain_info,
        latest_step.current_all_brain_info,
        latest_step.brain_name_to_action_info["testbrain"].outputs,
    )
    trainer_mock.process_experiences.assert_called_once_with(
        latest_step.previous_all_brain_info, latest_step.current_all_brain_info
    )
    # Inference-only path must drain the buffer rather than train on it.
    trainer_mock.clear_update_buffer.assert_called_once()
3 changes: 3 additions & 0 deletions ml-agents/mlagents/trainers/trainer_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,4 +296,7 @@ def advance(self, env: EnvManager) -> int:
with hierarchical_timer("update_policy"):
trainer.update_policy()
env.set_policy(brain_name, trainer.policy)
else:
# Avoid memory leak during inference
trainer.clear_update_buffer()
return len(new_step_infos)