Commit b3ca0c9

Author: Ervin T
Convert List[np.ndarray] to np.ndarray before using torch.as_tensor (#4183)
Big speedup in visual obs
1 parent a303586 commit b3ca0c9
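
The change itself is small: instead of handing torch.as_tensor a Python list of numpy arrays (which it converts element by element), the buffers are first collapsed into a single ndarray with np.asanyarray. A rough, standalone micro-benchmark of that difference is sketched below; it is not part of the commit, and the batch size and observation shape are invented placeholders.

# Hypothetical micro-benchmark, not from the ml-agents repo: compares
# torch.as_tensor on a Python list of numpy arrays against converting
# the list to a single ndarray first (the approach this commit adopts).
import timeit

import numpy as np
import torch

# Fake batch of "visual obs": 256 frames of 84x84x3 floats (placeholder sizes).
obs_list = [np.random.rand(84, 84, 3).astype(np.float32) for _ in range(256)]

def from_list() -> torch.Tensor:
    # Slow path: as_tensor converts each array in the Python list separately.
    return torch.as_tensor(obs_list)

def from_ndarray() -> torch.Tensor:
    # Fast path: np.asanyarray stacks the list into one contiguous ndarray,
    # which as_tensor can then wrap without a per-element loop.
    return torch.as_tensor(np.asanyarray(obs_list))

print("list of arrays :", timeit.timeit(from_list, number=10))
print("single ndarray :", timeit.timeit(from_ndarray, number=10))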

File tree

4 files changed (+28, -16 lines)

ml-agents/mlagents/trainers/models_torch.py

Lines changed: 12 additions & 1 deletion
@@ -1,8 +1,9 @@
 from enum import Enum
-from typing import Callable, NamedTuple
+from typing import Callable, NamedTuple, List, Optional

 import torch
 from torch import nn
+import numpy as np

 from mlagents.trainers.distributions_torch import (
     GaussianDistribution,
@@ -19,6 +20,16 @@
 EPSILON = 1e-7


+def list_to_tensor(
+    ndarray_list: List[np.ndarray], dtype: Optional[torch.dtype] = None
+) -> torch.Tensor:
+    """
+    Converts a list of numpy arrays into a tensor. MUCH faster than
+    calling as_tensor on the list directly.
+    """
+    return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype)
+
+
 class ActionType(Enum):
     DISCRETE = "discrete"
     CONTINUOUS = "continuous"
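
As a quick, standalone illustration of the new helper's behavior (the function body below is copied from the diff above; the sample data is invented), same-shaped arrays are stacked along a new leading batch dimension, and the result is only cast when a dtype is passed, which matches the torch.long and torch.int32 call sites in the optimizers below.

from typing import List, Optional

import numpy as np
import torch

def list_to_tensor(
    ndarray_list: List[np.ndarray], dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    # Same body as the helper added to models_torch.py above.
    return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype)

# Invented stand-in for a "visual_obs" buffer: 32 float32 frames of 84x84x3.
frames = [np.zeros((84, 84, 3), dtype=np.float32) for _ in range(32)]
vis = list_to_tensor(frames)
print(vis.shape, vis.dtype)   # torch.Size([32, 84, 84, 3]) torch.float32

# Discrete actions are cast explicitly, as in the dtype=torch.long call site.
acts = [np.array([1, 0]), np.array([2, 1])]
print(list_to_tensor(acts, dtype=torch.long).dtype)   # torch.int64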

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py

Lines changed: 4 additions & 3 deletions
@@ -12,6 +12,7 @@
 from mlagents.trainers.optimizer import Optimizer
 from mlagents.trainers.settings import TrainerSettings, RewardSignalType
 from mlagents.trainers.trajectory import SplitObservations
+from mlagents.trainers.models_torch import list_to_tensor


 class TorchOptimizer(Optimizer):  # pylint: disable=W0223
@@ -79,21 +80,21 @@ def get_value_estimates(
     def get_trajectory_value_estimates(
         self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
     ) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
-        vector_obs = [torch.as_tensor(batch["vector_obs"])]
+        vector_obs = [list_to_tensor(batch["vector_obs"])]
         if self.policy.use_vis_obs:
             visual_obs = []
             for idx, _ in enumerate(
                 self.policy.actor_critic.network_body.visual_encoders
             ):
-                visual_ob = torch.as_tensor(batch["visual_obs%d" % idx])
+                visual_ob = list_to_tensor(batch["visual_obs%d" % idx])
                 visual_obs.append(visual_ob)
         else:
             visual_obs = []

         memory = torch.zeros([1, len(vector_obs[0]), self.policy.m_size])

         next_obs = np.concatenate(next_obs, axis=-1)
-        next_obs = [torch.as_tensor(next_obs).unsqueeze(0)]
+        next_obs = [list_to_tensor(next_obs).unsqueeze(0)]
         next_memory = torch.zeros([1, 1, self.policy.m_size])

         value_estimates, mean_value = self.policy.actor_critic.critic_pass(

ml-agents/mlagents/trainers/policy/nn_policy.py

Lines changed: 0 additions & 1 deletion
@@ -21,7 +21,6 @@ def __init__(
         seed: int,
         brain: BrainParameters,
         trainer_settings: TrainerSettings,
-        is_training: bool,
         model_path: str,
         load: bool,
         tanh_squash: bool = False,

ml-agents/mlagents/trainers/ppo/optimizer_torch.py

Lines changed: 12 additions & 11 deletions
@@ -7,6 +7,7 @@
 from mlagents.trainers.policy.torch_policy import TorchPolicy
 from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
 from mlagents.trainers.settings import TrainerSettings, PPOSettings
+from mlagents.trainers.models_torch import list_to_tensor


 class TorchPPOOptimizer(TorchOptimizer):
@@ -91,18 +92,18 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         returns = {}
         old_values = {}
         for name in self.reward_signals:
-            old_values[name] = torch.as_tensor(batch["{}_value_estimates".format(name)])
-            returns[name] = torch.as_tensor(batch["{}_returns".format(name)])
+            old_values[name] = list_to_tensor(batch["{}_value_estimates".format(name)])
+            returns[name] = list_to_tensor(batch["{}_returns".format(name)])

-        vec_obs = [torch.as_tensor(batch["vector_obs"])]
-        act_masks = torch.as_tensor(batch["action_mask"])
+        vec_obs = [list_to_tensor(batch["vector_obs"])]
+        act_masks = list_to_tensor(batch["action_mask"])
         if self.policy.use_continuous_act:
-            actions = torch.as_tensor(batch["actions"]).unsqueeze(-1)
+            actions = list_to_tensor(batch["actions"]).unsqueeze(-1)
         else:
-            actions = torch.as_tensor(batch["actions"], dtype=torch.long)
+            actions = list_to_tensor(batch["actions"], dtype=torch.long)

         memories = [
-            torch.as_tensor(batch["memory"][i])
+            list_to_tensor(batch["memory"][i])
             for i in range(0, len(batch["memory"]), self.policy.sequence_length)
         ]
         if len(memories) > 0:
@@ -113,7 +114,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
             for idx, _ in enumerate(
                 self.policy.actor_critic.network_body.visual_encoders
             ):
-                vis_ob = torch.as_tensor(batch["visual_obs%d" % idx])
+                vis_ob = list_to_tensor(batch["visual_obs%d" % idx])
                 vis_obs.append(vis_ob)
         else:
             vis_obs = []
@@ -127,10 +128,10 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         )
         value_loss = self.ppo_value_loss(values, old_values, returns)
         policy_loss = self.ppo_policy_loss(
-            torch.as_tensor(batch["advantages"]),
+            list_to_tensor(batch["advantages"]),
             log_probs,
-            torch.as_tensor(batch["action_probs"]),
-            torch.as_tensor(batch["masks"], dtype=torch.int32),
+            list_to_tensor(batch["action_probs"]),
+            list_to_tensor(batch["masks"], dtype=torch.int32),
         )
         loss = (
             policy_loss
