
Commit a64d02a

Chris Elion authored
cherry pick PR#3032 (#3066)
1 parent 164d1ab commit a64d02a

File tree

2 files changed: 11 additions & 8 deletions


ml-agents/mlagents/trainers/ppo/policy.py

Lines changed: 5 additions & 2 deletions
@@ -249,10 +249,13 @@ def get_value_estimates(
         ]
         if self.use_vec_obs:
             feed_dict[self.model.vector_in] = [brain_info.vector_observations[idx]]
+        agent_id = brain_info.agents[idx]
         if self.use_recurrent:
-            feed_dict[self.model.memory_in] = self.retrieve_memories([idx])
+            feed_dict[self.model.memory_in] = self.retrieve_memories([agent_id])
         if not self.use_continuous_act and self.use_recurrent:
-            feed_dict[self.model.prev_action] = self.retrieve_previous_action([idx])
+            feed_dict[self.model.prev_action] = self.retrieve_previous_action(
+                [agent_id]
+            )
         value_estimates = self.sess.run(self.model.value_heads, feed_dict)

         value_estimates = {k: float(v) for k, v in value_estimates.items()}
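The substance of this hunk: get_value_estimates was looking up recurrent memories and previous actions with the batch index idx, but (per the tf_policy.py change below) those dictionaries are keyed by agent id, so the index-based lookup could never hit. The fix resolves idx to agent_id = brain_info.agents[idx] first. A minimal standalone sketch of the failure mode, using hypothetical names rather than the actual ML-Agents classes:

    import numpy as np

    m_size = 2
    memory_dict = {}  # keyed by agent id, as in tf_policy.py

    def save_memories(agent_ids, memory_matrix):
        for index, agent_id in enumerate(agent_ids):
            memory_dict[agent_id] = memory_matrix[index, :]

    def retrieve_memories(agent_ids):
        # Unknown keys fall back to zero rows.
        out = np.zeros((len(agent_ids), m_size))
        for index, agent_id in enumerate(agent_ids):
            if agent_id in memory_dict:
                out[index, :] = memory_dict[agent_id]
        return out

    save_memories(["agent-7", "agent-9"], np.array([[1.0, 1.0], [2.0, 2.0]]))

    idx = 0  # batch index of "agent-7" in this step's BrainInfo
    print(retrieve_memories([idx]))        # [[0. 0.]]  old behavior: idx is not a key
    print(retrieve_memories(["agent-7"]))  # [[1. 1.]]  new behavior: key by agent id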

ml-agents/mlagents/trainers/tf_policy.py

Lines changed: 6 additions & 6 deletions
@@ -56,9 +56,9 @@ def __init__(self, seed, brain, trainer_parameters):
         self.seed = seed
         self.brain = brain
         self.use_recurrent = trainer_parameters["use_recurrent"]
-        self.memory_dict: Dict[int, np.ndarray] = {}
+        self.memory_dict: Dict[str, np.ndarray] = {}
         self.num_branches = len(self.brain.vector_action_space_size)
-        self.previous_action_dict: Dict[int, np.array] = {}
+        self.previous_action_dict: Dict[str, np.array] = {}
         self.normalize = trainer_parameters.get("normalize", False)
         self.use_continuous_act = brain.vector_action_space_type == "continuous"
         if self.use_continuous_act:
@@ -181,14 +181,14 @@ def make_empty_memory(self, num_agents):
         return np.zeros((num_agents, self.m_size), dtype=np.float)

     def save_memories(
-        self, agent_ids: List[int], memory_matrix: Optional[np.ndarray]
+        self, agent_ids: List[str], memory_matrix: Optional[np.ndarray]
     ) -> None:
         if memory_matrix is None:
             return
         for index, agent_id in enumerate(agent_ids):
             self.memory_dict[agent_id] = memory_matrix[index, :]

-    def retrieve_memories(self, agent_ids: List[int]) -> np.ndarray:
+    def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray:
         memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float)
         for index, agent_id in enumerate(agent_ids):
             if agent_id in self.memory_dict:
@@ -209,14 +209,14 @@ def make_empty_previous_action(self, num_agents):
         return np.zeros((num_agents, self.num_branches), dtype=np.int)

     def save_previous_action(
-        self, agent_ids: List[int], action_matrix: Optional[np.ndarray]
+        self, agent_ids: List[str], action_matrix: Optional[np.ndarray]
     ) -> None:
         if action_matrix is None:
             return
         for index, agent_id in enumerate(agent_ids):
             self.previous_action_dict[agent_id] = action_matrix[index, :]

-    def retrieve_previous_action(self, agent_ids: List[int]) -> np.ndarray:
+    def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray:
         action_matrix = np.zeros((len(agent_ids), self.num_branches), dtype=np.int)
         for index, agent_id in enumerate(agent_ids):
             if agent_id in self.previous_action_dict:
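The tf_policy.py side of the change is type-level: the memory and previous-action dictionaries are keyed by the agent ids that the lookup above now passes in, so the annotations move from int to str keys. A self-contained sketch of the save/retrieve round trip under the new annotations (PreviousActionStore is a hypothetical name, and np.int64 stands in for the deprecated np.int):

    from typing import Dict, List, Optional
    import numpy as np

    class PreviousActionStore:
        def __init__(self, num_branches: int) -> None:
            self.num_branches = num_branches
            self.previous_action_dict: Dict[str, np.ndarray] = {}

        def save_previous_action(
            self, agent_ids: List[str], action_matrix: Optional[np.ndarray]
        ) -> None:
            if action_matrix is None:
                return
            for index, agent_id in enumerate(agent_ids):
                self.previous_action_dict[agent_id] = action_matrix[index, :]

        def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray:
            # Rows for agents without a stored action stay zero.
            action_matrix = np.zeros((len(agent_ids), self.num_branches), dtype=np.int64)
            for index, agent_id in enumerate(agent_ids):
                if agent_id in self.previous_action_dict:
                    action_matrix[index, :] = self.previous_action_dict[agent_id]
            return action_matrix

    store = PreviousActionStore(num_branches=2)
    store.save_previous_action(["agent-7"], np.array([[3, 1]]))
    print(store.retrieve_previous_action(["agent-7", "agent-9"]))
    # [[3 1]
    #  [0 0]]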
