
Commit dc88653

vincentpierre and Ervin T committed
Hotfix memory leak on Python (#3664)
* Hotfix memory leak on Python
* Fixing
* Fixing a bug in the heuristic policy. A decision should not be requested when the agent is done
* [bug-fix] Make Python able to deal with 0-step episodes (#3671)
* adding some comments

Co-authored-by: Ervin T <ervin@unity3d.com>
1 parent de95489 commit dc88653
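
One way to read the changes below: the Python-side AgentProcessor keeps per-agent bookkeeping dictionaries keyed by a global agent id, and cleanup was only triggered for agents that had recorded at least one step, so agents that reported done immediately after a reset (0-step episodes) left stale entries behind. A minimal standalone sketch of that failure mode, with made-up names, not ml-agents code:

    # Hypothetical illustration of the leak pattern addressed by this commit:
    # cleanup gated on "had at least one step" never removes 0-step-episode agents.
    episode_steps = {}

    def on_agent_done(agent_id: str, had_a_step: bool) -> None:
        episode_steps.setdefault(agent_id, 0)
        if had_a_step:
            del episode_steps[agent_id]  # 0-step episodes skip this branch -> stale entry

    for i in range(1000):
        on_agent_done(f"agent-{i}", had_a_step=False)

    print(len(episode_steps))  # 1000 entries that are never freed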

File tree

6 files changed: +52 -32 lines


com.unity.ml-agents/Runtime/Agent.cs

Lines changed: 2 additions & 1 deletion
@@ -315,6 +315,7 @@ protected virtual void OnDisable()
 
         void NotifyAgentDone(DoneReason doneReason)
         {
+            m_Info.episodeId = m_EpisodeId;
             m_Info.reward = m_Reward;
             m_Info.done = true;
             m_Info.maxStepReached = doneReason == DoneReason.MaxStepReached;
@@ -376,7 +377,7 @@ public void SetModel(
                 // If everything is the same, don't make any changes.
                 return;
             }
-
+            NotifyAgentDone(DoneReason.Disabled);
             m_PolicyFactory.model = model;
             m_PolicyFactory.inferenceDevice = inferenceDevice;
             m_PolicyFactory.behaviorName = behaviorName;

com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs

Lines changed: 12 additions & 5 deletions
@@ -458,13 +458,20 @@ UnityRLInitializationOutputProto GetTempUnityRlInitializationOutput()
             {
                 if (m_CurrentUnityRlOutput.AgentInfos.ContainsKey(behaviorName))
                 {
-                    if (output == null)
+                    if (m_CurrentUnityRlOutput.AgentInfos[behaviorName].CalculateSize() > 0)
                     {
-                        output = new UnityRLInitializationOutputProto();
-                    }
+                        // Only send the BrainParameters if there is a non empty list of
+                        // AgentInfos ready to be sent.
+                        // This is to ensure that The Python side will always have a first
+                        // observation when receiving the BrainParameters
+                        if (output == null)
+                        {
+                            output = new UnityRLInitializationOutputProto();
+                        }
 
-                    var brainParameters = m_UnsentBrainKeys[behaviorName];
-                    output.BrainParameters.Add(brainParameters.ToProto(behaviorName, true));
+                        var brainParameters = m_UnsentBrainKeys[behaviorName];
+                        output.BrainParameters.Add(brainParameters.ToProto(behaviorName, true));
+                    }
                 }
             }
 
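
In rough Python terms, the new guard means the initialization output only carries BrainParameters once at least one AgentInfo for that behavior is queued, so the trainer never receives brain parameters without a first observation. A minimal sketch under that assumption, using dict stand-ins rather than the actual C#/protobuf types:

    from typing import Dict, List, Optional

    def add_brain_params_if_ready(
        agent_infos: Dict[str, List[dict]],    # stand-in for queued AgentInfos per behavior
        unsent_brain_params: Dict[str, dict],  # stand-in for m_UnsentBrainKeys
        behavior_name: str,
        output: Optional[dict] = None,
    ) -> Optional[dict]:
        # Only attach brain parameters when at least one agent observation is queued,
        # mirroring the CalculateSize() > 0 check added in the C# diff above.
        if behavior_name in agent_infos and len(agent_infos[behavior_name]) > 0:
            if output is None:
                output = {"brain_parameters": []}
            output["brain_parameters"].append(unsent_brain_params[behavior_name])
        return output

    # With no queued AgentInfos, nothing is sent:
    print(add_brain_params_if_ready({"Walker": []}, {"Walker": {"vec_obs_size": 8}}, "Walker"))  # None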

com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs

Lines changed: 4 additions & 1 deletion
@@ -29,7 +29,10 @@ public HeuristicPolicy(Func<float[]> heuristic)
         public void RequestDecision(AgentInfo info, List<ISensor> sensors)
         {
             StepSensors(sensors);
-            m_LastDecision = m_Heuristic.Invoke();
+            if (!info.done)
+            {
+                m_LastDecision = m_Heuristic.Invoke();
+            }
         }
 
         /// <inheritdoc />

ml-agents/mlagents/trainers/agent_processor.py

Lines changed: 20 additions & 11 deletions
@@ -1,5 +1,5 @@
 import sys
-from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Set
+from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Any
 from collections import defaultdict, Counter, deque
 
 from mlagents_envs.base_env import BatchedStepResult, StepResult
@@ -66,7 +66,6 @@ def add_experiences(
             for _entropy in take_action_outputs["entropy"]:
                 self.stats_reporter.add_stat("Policy/Entropy", _entropy)
 
-        terminated_agents: Set[str] = set()
        # Make unique agent_ids that are global across workers
        action_global_agent_ids = [
            get_global_agent_id(worker_id, ag_id) for ag_id in previous_action.agent_ids
@@ -85,6 +84,7 @@ def add_experiences(
            stored_take_action_outputs = self.last_take_action_outputs.get(
                global_id, None
            )
+
            if stored_agent_step is not None and stored_take_action_outputs is not None:
                # We know the step is from the same worker, so use the local agent id.
                obs = stored_agent_step.obs
@@ -143,6 +143,8 @@ def add_experiences(
                        traj_queue.put(trajectory)
                    self.experience_buffers[global_id] = []
                    if curr_agent_step.done:
+                        # Record episode length for agents which have had at least
+                        # 1 step. Done after reset ignored.
                        self.stats_reporter.add_stat(
                            "Environment/Cumulative Reward",
                            self.episode_rewards.get(global_id, 0),
@@ -151,7 +153,6 @@ def add_experiences(
                            "Environment/Episode Length",
                            self.episode_steps.get(global_id, 0),
                        )
-                        terminated_agents.add(global_id)
            elif not curr_agent_step.done:
                self.episode_steps[global_id] += 1
 
@@ -160,9 +161,9 @@ def add_experiences(
                curr_agent_step,
                batched_step_result.agent_id_to_index[_id],
            )
-
-        for terminated_id in terminated_agents:
-            self._clean_agent_data(terminated_id)
+            # Delete all done agents, regardless of if they had a 0-length episode.
+            if curr_agent_step.done:
+                self._clean_agent_data(global_id)
 
        for _gid in action_global_agent_ids:
            # If the ID doesn't have a last step result, the agent just reset,
@@ -177,14 +178,22 @@ def _clean_agent_data(self, global_id: str) -> None:
        """
        Removes the data for an Agent.
        """
-        del self.experience_buffers[global_id]
-        del self.last_take_action_outputs[global_id]
-        del self.last_step_result[global_id]
-        del self.episode_steps[global_id]
-        del self.episode_rewards[global_id]
+        self._safe_delete(self.experience_buffers, global_id)
+        self._safe_delete(self.last_take_action_outputs, global_id)
+        self._safe_delete(self.last_step_result, global_id)
+        self._safe_delete(self.episode_steps, global_id)
+        self._safe_delete(self.episode_rewards, global_id)
        self.policy.remove_previous_action([global_id])
        self.policy.remove_memories([global_id])
 
+    def _safe_delete(self, my_dictionary: Dict[Any, Any], key: Any) -> None:
+        """
+        Safe removes data from a dictionary. If not found,
+        don't delete.
+        """
+        if key in my_dictionary:
+            del my_dictionary[key]
+
    def publish_trajectory_queue(
        self, trajectory_queue: "AgentManagerQueue[Trajectory]"
    ) -> None:
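
The switch from bare del to _safe_delete matters because an agent that is done on its very first step after a reset may never have created entries in some of these dictionaries, so unconditional deletion would raise KeyError. A minimal standalone sketch of the pattern using plain dicts, not the actual AgentProcessor:

    from typing import Any, Dict

    def safe_delete(my_dictionary: Dict[Any, Any], key: Any) -> None:
        # Remove the key only if it is present; a 0-step episode may never
        # have created an entry, so a bare `del` would raise KeyError.
        if key in my_dictionary:
            del my_dictionary[key]

    episode_steps: Dict[str, int] = {"0-1": 12}
    safe_delete(episode_steps, "0-1")   # entry removed
    safe_delete(episode_steps, "0-99")  # missing key is ignored, no KeyError
    print(episode_steps)                # {}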

ml-agents/mlagents/trainers/policy/tf_policy.py

Lines changed: 5 additions & 14 deletions
@@ -174,17 +174,6 @@ def get_action(
        if batched_step_result.n_agents() == 0:
            return ActionInfo.empty()
 
-        agents_done = [
-            agent
-            for agent, done in zip(
-                batched_step_result.agent_id, batched_step_result.done
-            )
-            if done
-        ]
-
-        self.remove_memories(agents_done)
-        self.remove_previous_action(agents_done)
-
        global_agent_ids = [
            get_global_agent_id(worker_id, int(agent_id))
            for agent_id in batched_step_result.agent_id
@@ -379,9 +368,11 @@ def _initialize_tensorflow_references(self):
 
    def create_input_placeholders(self):
        with self.graph.as_default():
-            self.global_step, self.increment_step_op, self.steps_to_increment = (
-                ModelUtils.create_global_steps()
-            )
+            (
+                self.global_step,
+                self.increment_step_op,
+                self.steps_to_increment,
+            ) = ModelUtils.create_global_steps()
            self.visual_in = ModelUtils.create_visual_input_placeholders(
                self.brain.camera_resolutions
            )

ml-agents/mlagents/trainers/tests/test_agent_processor.py

Lines changed: 9 additions & 0 deletions
@@ -152,6 +152,15 @@ def test_agent_deletion():
    assert len(processor.last_take_action_outputs.keys()) == 0
    assert len(processor.episode_steps.keys()) == 0
    assert len(processor.episode_rewards.keys()) == 0
+    assert len(processor.last_step_result.keys()) == 0
+
+    # check that steps with immediate dones don't add to dicts
+    processor.add_experiences(mock_done_step, 0, ActionInfo.empty())
+    assert len(processor.experience_buffers.keys()) == 0
+    assert len(processor.last_take_action_outputs.keys()) == 0
+    assert len(processor.episode_steps.keys()) == 0
+    assert len(processor.episode_rewards.keys()) == 0
+    assert len(processor.last_step_result.keys()) == 0
 
 
 def test_end_episode():

0 commit comments
