Skip to content

Commit 5d5fe57

Browse files
Ervin TChris Elion
Ervin T
authored and
Chris Elion
committed
Fix batch size issue with BC (#2965)
1 parent 2287c06 commit 5d5fe57

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

ml-agents/mlagents/trainers/bc/trainer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,14 @@ def update_policy(self):
122122
"""
123123
self.demonstration_buffer.update_buffer.shuffle(self.policy.sequence_length)
124124
batch_losses = []
125+
batch_size = self.n_sequences * self.policy.sequence_length
126+
# We either divide the entire buffer into num_batches batches, or limit the number
127+
# of batches to batches_per_epoch.
125128
num_batches = min(
126-
len(self.demonstration_buffer.update_buffer["actions"]) // self.n_sequences,
129+
len(self.demonstration_buffer.update_buffer["actions"]) // batch_size,
127130
self.batches_per_epoch,
128131
)
129132

130-
batch_size = self.n_sequences * self.policy.sequence_length
131-
132133
for i in range(0, num_batches * batch_size, batch_size):
133134
update_buffer = self.demonstration_buffer.update_buffer
134135
mini_batch = update_buffer.make_mini_batch(i, i + batch_size)

ml-agents/mlagents/trainers/tests/test_bc.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ def dummy_config():
2525
use_recurrent: false
2626
sequence_length: 32
2727
memory_size: 32
28-
batches_per_epoch: 1
28+
batches_per_epoch: 100 # Force code to use all possible batches
2929
batch_size: 32
3030
summary_freq: 2000
3131
max_steps: 4000
3232
"""
3333
)
3434

3535

36-
def create_bc_trainer(dummy_config, is_discrete=False):
36+
def create_bc_trainer(dummy_config, is_discrete=False, use_recurrent=False):
3737
mock_env = mock.Mock()
3838
if is_discrete:
3939
mock_brain = mb.create_mock_pushblock_brain()
@@ -54,15 +54,17 @@ def create_bc_trainer(dummy_config, is_discrete=False):
5454
trainer_parameters["demo_path"] = (
5555
os.path.dirname(os.path.abspath(__file__)) + "/test.demo"
5656
)
57+
trainer_parameters["use_recurrent"] = use_recurrent
5758
trainer = BCTrainer(
5859
mock_brain, trainer_parameters, training=True, load=False, seed=0, run_id=0
5960
)
6061
trainer.demonstration_buffer = mb.simulate_rollout(env, trainer.policy, 100)
6162
return trainer, env
6263

6364

64-
def test_bc_trainer_step(dummy_config):
65-
trainer, env = create_bc_trainer(dummy_config)
65+
@pytest.mark.parametrize("use_recurrent", [True, False])
66+
def test_bc_trainer_step(dummy_config, use_recurrent):
67+
trainer, env = create_bc_trainer(dummy_config, use_recurrent=use_recurrent)
6668
# Test get_step
6769
assert trainer.get_step == 0
6870
# Test update policy

0 commit comments

Comments
 (0)