Skip to content

Commit e5395a3

Browse files
author
Ervin T
authored
[perf] Optimizations for performance (#5192)
* Lazy init the buffer when sampling * Update references rather than copy data * Don't create unneeded numpy arrays * Remove self[key] from loop
1 parent 1c7e64c commit e5395a3

File tree

3 files changed

+25
-20
lines changed

3 files changed

+25
-20
lines changed

ml-agents/mlagents/trainers/buffer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -394,10 +394,11 @@ def shuffle(
394394
s = np.arange(len(self[key_list[0]]) // sequence_length)
395395
np.random.shuffle(s)
396396
for key in key_list:
397+
buffer_field = self[key]
397398
tmp: List[np.ndarray] = []
398399
for i in s:
399-
tmp += self[key][i * sequence_length : (i + 1) * sequence_length]
400-
self[key][:] = tmp
400+
tmp += buffer_field[i * sequence_length : (i + 1) * sequence_length]
401+
buffer_field.set(tmp)
401402

402403
def make_mini_batch(self, start: int, end: int) -> "AgentBuffer":
403404
"""
@@ -430,7 +431,8 @@ def sample_mini_batch(
430431
* sequence_length
431432
) # Sample random sequence starts
432433
for key in self:
433-
mb_list = [self[key][i : i + sequence_length] for i in start_idxes]
434+
buffer_field = self[key]
435+
mb_list = (buffer_field[i : i + sequence_length] for i in start_idxes)
434436
# See comparison of ways to make a list from a list of lists here:
435437
# https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-list-of-lists
436438
mini_batch[key].set(list(itertools.chain.from_iterable(mb_list)))

ml-agents/mlagents/trainers/torch/encoders.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,23 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
2323
return normalized_state
2424

2525
def update(self, vector_input: torch.Tensor) -> None:
26-
steps_increment = vector_input.size()[0]
27-
total_new_steps = self.normalization_steps + steps_increment
28-
29-
input_to_old_mean = vector_input - self.running_mean
30-
new_mean = self.running_mean + (input_to_old_mean / total_new_steps).sum(0)
31-
32-
input_to_new_mean = vector_input - new_mean
33-
new_variance = self.running_variance + (
34-
input_to_new_mean * input_to_old_mean
35-
).sum(0)
36-
# Update in-place
37-
self.running_mean.data.copy_(new_mean.data)
38-
self.running_variance.data.copy_(new_variance.data)
39-
self.normalization_steps.data.copy_(total_new_steps.data)
26+
with torch.no_grad():
27+
steps_increment = vector_input.size()[0]
28+
total_new_steps = self.normalization_steps + steps_increment
29+
30+
input_to_old_mean = vector_input - self.running_mean
31+
new_mean: torch.Tensor = self.running_mean + (
32+
input_to_old_mean / total_new_steps
33+
).sum(0)
34+
35+
input_to_new_mean = vector_input - new_mean
36+
new_variance = self.running_variance + (
37+
input_to_new_mean * input_to_old_mean
38+
).sum(0)
39+
# Update references. This is much faster than in-place data update.
40+
self.running_mean: torch.Tensor = new_mean
41+
self.running_variance: torch.Tensor = new_variance
42+
self.normalization_steps: torch.Tensor = total_new_steps
4043

4144
def copy_from(self, other_normalizer: "Normalizer") -> None:
4245
self.normalization_steps.data.copy_(other_normalizer.normalization_steps.data)

ml-agents/mlagents/trainers/trajectory.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,13 +246,13 @@ def to_agentbuffer(self) -> AgentBuffer:
246246
exp.action.discrete
247247
)
248248

249-
cont_next_actions = np.zeros_like(exp.action.continuous)
250-
disc_next_actions = np.zeros_like(exp.action.discrete)
251-
252249
if not is_last_step:
253250
next_action = self.steps[step + 1].action
254251
cont_next_actions = next_action.continuous
255252
disc_next_actions = next_action.discrete
253+
else:
254+
cont_next_actions = np.zeros_like(exp.action.continuous)
255+
disc_next_actions = np.zeros_like(exp.action.discrete)
256256

257257
agent_buffer_trajectory[BufferKey.NEXT_CONT_ACTION].append(
258258
cont_next_actions

0 commit comments

Comments
 (0)