Commit 13beeac

Hotfix 0.3.1b (#637)
* [Fix] Use the stored agent info instead of the previous agent info when bootstrapping the value
* [Bug Fix] Addressed #643
* [Added Line Break]
1 parent 5165e88 commit 13beeac

File tree

2 files changed: +8 -6 lines changed

python/unitytrainers/bc/trainer.py

Lines changed: 3 additions & 2 deletions
@@ -185,8 +185,9 @@ def add_experiences(self, curr_info: AllBrainInfo, next_info: AllBrainInfo, take
             else:
                 idx = stored_info_teacher.agents.index(agent_id)
                 next_idx = next_info_teacher.agents.index(agent_id)
-                if info_teacher.text_observations[idx] != "":
-                    info_teacher_record, info_teacher_reset = info_teacher.text_observations[idx].lower().split(",")
+                if stored_info_teacher.text_observations[idx] != "":
+                    info_teacher_record, info_teacher_reset = \
+                        stored_info_teacher.text_observations[idx].lower().split(",")
                     next_info_teacher_record, next_info_teacher_reset = next_info_teacher.text_observations[idx].\
                         lower().split(",")
                     if next_info_teacher_reset == "true":
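The teacher-branch fix above enforces an indexing invariant: `idx` is computed against `stored_info_teacher.agents`, so it is only safe for reads from that same snapshot. A minimal self-contained sketch of the bug class being removed, using a hypothetical `BrainInfoStub` stand-in rather than the real `BrainInfo` class:

# Sketch of the indexing invariant behind the bc/trainer.py fix.
# BrainInfoStub is a hypothetical stand-in for ML-Agents' BrainInfo.

class BrainInfoStub:
    def __init__(self, agents, text_observations):
        self.agents = agents                        # agent ids, in this snapshot's order
        self.text_observations = text_observations  # one "record,reset" string per agent

stored_info_teacher = BrainInfoStub([7, 3], ["true,false", "false,true"])
info_teacher = BrainInfoStub([3, 7], ["false,true", "true,false"])  # different agent ordering

agent_id = 3
idx = stored_info_teacher.agents.index(agent_id)  # idx is relative to stored_info_teacher

# Fixed pattern: read from the snapshot the index was resolved against.
record, reset = stored_info_teacher.text_observations[idx].lower().split(",")

# Removed pattern: the same idx applied to a differently ordered snapshot
# silently returns another agent's observation.
assert info_teacher.text_observations[idx] != stored_info_teacher.text_observations[idx]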

python/unitytrainers/ppo/trainer.py

Lines changed: 5 additions & 4 deletions
@@ -269,18 +269,20 @@ def process_experiences(self, current_info: AllBrainInfo, new_info: AllBrainInfo
         """
 
         info = new_info[self.brain_name]
-        last_info = current_info[self.brain_name]
         for l in range(len(info.agents)):
             agent_actions = self.training_buffer[info.agents[l]]['actions']
             if ((info.local_done[l] or len(agent_actions) > self.trainer_parameters['time_horizon'])
                     and len(agent_actions) > 0):
+                agent_id = info.agents[l]
                 if info.local_done[l] and not info.max_reached[l]:
                     value_next = 0.0
                 else:
                     if info.max_reached[l]:
-                        bootstrapping_info = last_info
+                        bootstrapping_info = self.training_buffer[agent_id].last_brain_info
+                        idx = bootstrapping_info.agents.index(agent_id)
                     else:
                         bootstrapping_info = info
+                        idx = l
                     feed_dict = {self.model.batch_size: len(bootstrapping_info.vector_observations), self.model.sequence_length: 1}
                     if self.use_observations:
                         for i in range(len(bootstrapping_info.visual_observations)):
@@ -293,8 +295,7 @@ def process_experiences(self, current_info: AllBrainInfo, new_info: AllBrainInfo
                         feed_dict[self.model.memory_in] = bootstrapping_info.memories
                     if not self.is_continuous_action and self.use_recurrent:
                         feed_dict[self.model.prev_action] = np.reshape(bootstrapping_info.previous_vector_actions, [-1])
-                    value_next = self.sess.run(self.model.value, feed_dict)[l]
-                agent_id = info.agents[l]
+                    value_next = self.sess.run(self.model.value, feed_dict)[idx]
 
                 self.training_buffer[agent_id]['advantages'].set(
                     get_gae(
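The PPO change applies the same invariant when bootstrapping the value of an episode cut off at max steps: the value estimate must come from the agent's stored last brain info, and the row index must be re-resolved inside that snapshot instead of reusing the loop index `l`. A hedged sketch of the control flow, where `SnapshotStub` and `bootstrap_value` are illustrative stand-ins (not ML-Agents APIs) for `BrainInfo` and the trainer's `sess.run` value call:

# Sketch of the bootstrap-index fix in ppo/trainer.py; SnapshotStub and
# bootstrap_value are hypothetical stand-ins, not ML-Agents APIs.
import numpy as np

class SnapshotStub:
    def __init__(self, agents, values):
        self.agents = agents              # agent ids in this snapshot's order
        self.values = np.asarray(values)  # per-agent value estimates, same order

def bootstrap_value(current, stored_last, l, max_reached):
    agent_id = current.agents[l]
    if max_reached:
        # Episode was truncated: bootstrap from the agent's stored last info
        # and re-resolve the row index inside *that* snapshot ...
        info = stored_last
        idx = info.agents.index(agent_id)
    else:
        info = current
        idx = l
    # ... so values[idx] is this agent's estimate, not whichever agent
    # happened to sit at position l (the pre-fix bug).
    return info.values[idx]

current = SnapshotStub([7, 3], [0.10, 0.20])
stored_last = SnapshotStub([3, 7], [0.25, 0.05])
assert bootstrap_value(current, stored_last, l=1, max_reached=True) == 0.25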
