Skip to content

Commit 4cb9384

Browse files
author
Chris Elion
authored
[MLA-1724] Reduce use of IEnumerable during inference (#4887)
* improve allocations in inference * Add IList overload for VectorSensor.AddObservation * [skip ci] changelog * [skip ci] migrating
1 parent f387cdf commit 4cb9384

File tree

13 files changed

+126
-62
lines changed

13 files changed

+126
-62
lines changed

Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public virtual int Write(ObservationWriter writer)
3333
float[] buffer = new float[numFloats];
3434
WriteObservation(buffer);
3535

36-
writer.AddRange(buffer);
36+
writer.AddList(buffer);
3737

3838
return numFloats;
3939
}

com.unity.ml-agents/CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,21 @@ removed when training with a player. The Editor still requires it to be clamped
2323
- Added the IHeuristicProvider interface to allow IActuators as well as Agents to implement the Heuristic function to generate actions.
2424
Updated the Basic example and the Match3 Example to use Actuators.
2525
Changed the namespace and file names of classes in com.unity.ml-agents.extensions. (#4849)
26+
- Added `VectorSensor.AddObservation(IList<float>)`. `VectorSensor.AddObservation(IEnumerable<float>)`
27+
is deprecated. The `IList` version is recommended, as it does not generate any
28+
additional memory allocations. (#4887)
29+
- Added `ObservationWriter.AddList()` and deprecated `ObservationWriter.AddRange()`.
30+
`AddList()` is recommended, as it does not generate any additional memory allocations. (#4887)
2631

2732
#### ml-agents / ml-agents-envs / gym-unity (Python)
2833

2934
### Bug Fixes
3035
#### com.unity.ml-agents (C#)
3136
- Fix a compile warning about using an obsolete enum in `GrpcExtensions.cs`. (#4812)
3237
- CameraSensor now logs an error if the GraphicsDevice is null. (#4880)
38+
- Removed several memory allocations that happened during inference. On a test scene, this
39+
reduced the amount of memory allocated by approximately 25%. (#4887)
40+
3341
#### ml-agents / ml-agents-envs / gym-unity (Python)
3442
- Fixed a bug that would cause an exception when `RunOptions` was deserialized via `pickle`. (#4842)
3543
- Fixed a bug that could cause a crash if a behavior appeared partway through training in multi-environment training. (#4872)

com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@ public ContinuousActionOutputApplier(ActionSpec actionSpec)
2121
m_ActionSpec = actionSpec;
2222
}
2323

24-
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
24+
public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
2525
{
2626
var actionSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
2727
var agentIndex = 0;
28-
foreach (int agentId in actionIds)
28+
for (var i = 0; i < actionIds.Count; i++)
2929
{
30+
var agentId = actionIds[i];
3031
if (lastActions.ContainsKey(agentId))
3132
{
3233
var actionBuffer = lastActions[agentId];
@@ -65,7 +66,7 @@ public DiscreteActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAlloc
6566
m_ActionSpec = actionSpec;
6667
}
6768

68-
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
69+
public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
6970
{
7071
//var tensorDataProbabilities = tensorProxy.Data as float[,];
7172
var idActionPairList = actionIds as List<int> ?? actionIds.ToList();
@@ -109,9 +110,11 @@ public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionar
109110
actionProbs.data.Dispose();
110111
outputTensor.data.Dispose();
111112
}
113+
112114
var agentIndex = 0;
113-
foreach (int agentId in actionIds)
115+
for (var i = 0; i < actionIds.Count; i++)
114116
{
117+
var agentId = actionIds[i];
115118
if (lastActions.ContainsKey(agentId))
116119
{
117120
var actionBuffer = lastActions[agentId];
@@ -209,12 +212,13 @@ public MemoryOutputApplier(
209212
m_Memories = memories;
210213
}
211214

212-
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
215+
public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
213216
{
214217
var agentIndex = 0;
215218
var memorySize = (int)tensorProxy.shape[tensorProxy.shape.Length - 1];
216-
foreach (int agentId in actionIds)
219+
for (var i = 0; i < actionIds.Count; i++)
217220
{
221+
var agentId = actionIds[i];
218222
List<float> memory;
219223
if (!m_Memories.TryGetValue(agentId, out memory)
220224
|| memory.Count < memorySize)
@@ -246,13 +250,14 @@ public BarracudaMemoryOutputApplier(
246250
m_Memories = memories;
247251
}
248252

249-
public void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
253+
public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
250254
{
251255
var agentIndex = 0;
252256
var memorySize = (int)tensorProxy.shape[tensorProxy.shape.Length - 1];
253257

254-
foreach (int agentId in actionIds)
258+
for (var i = 0; i < actionIds.Count; i++)
255259
{
260+
var agentId = actionIds[i];
256261
List<float> memory;
257262
if (!m_Memories.TryGetValue(agentId, out memory)
258263
|| memory.Count < memorySize * m_MemoriesCount)

com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public BiDimensionalOutputGenerator(ITensorAllocator allocator)
2121
m_Allocator = allocator;
2222
}
2323

24-
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
24+
public void Generate(TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
2525
{
2626
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
2727
}
@@ -40,7 +40,7 @@ public BatchSizeGenerator(ITensorAllocator allocator)
4040
m_Allocator = allocator;
4141
}
4242

43-
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
43+
public void Generate(TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
4444
{
4545
tensorProxy.data?.Dispose();
4646
tensorProxy.data = m_Allocator.Alloc(new TensorShape(1, 1));
@@ -63,7 +63,7 @@ public SequenceLengthGenerator(ITensorAllocator allocator)
6363
m_Allocator = allocator;
6464
}
6565

66-
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
66+
public void Generate(TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
6767
{
6868
tensorProxy.shape = new long[0];
6969
tensorProxy.data?.Dispose();
@@ -92,14 +92,15 @@ public RecurrentInputGenerator(
9292
}
9393

9494
public void Generate(
95-
TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
95+
TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
9696
{
9797
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
9898

9999
var memorySize = tensorProxy.shape[tensorProxy.shape.Length - 1];
100100
var agentIndex = 0;
101-
foreach (var infoSensorPair in infos)
101+
for (var infoIndex = 0; infoIndex < infos.Count; infoIndex++)
102102
{
103+
var infoSensorPair = infos[infoIndex];
103104
var info = infoSensorPair.agentInfo;
104105
List<float> memory;
105106

@@ -147,14 +148,15 @@ public BarracudaRecurrentInputGenerator(
147148
m_Memories = memories;
148149
}
149150

150-
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
151+
public void Generate(TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
151152
{
152153
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
153154

154155
var memorySize = (int)tensorProxy.shape[tensorProxy.shape.Length - 1];
155156
var agentIndex = 0;
156-
foreach (var infoSensorPair in infos)
157+
for (var infoIndex = 0; infoIndex < infos.Count; infoIndex++)
157158
{
159+
var infoSensorPair = infos[infoIndex];
158160
var info = infoSensorPair.agentInfo;
159161
var offset = memorySize * m_MemoryIndex;
160162
List<float> memory;
@@ -200,14 +202,15 @@ public PreviousActionInputGenerator(ITensorAllocator allocator)
200202
m_Allocator = allocator;
201203
}
202204

203-
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
205+
public void Generate(TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
204206
{
205207
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
206208

207209
var actionSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
208210
var agentIndex = 0;
209-
foreach (var infoSensorPair in infos)
211+
for (var infoIndex = 0; infoIndex < infos.Count; infoIndex++)
210212
{
213+
var infoSensorPair = infos[infoIndex];
211214
var info = infoSensorPair.agentInfo;
212215
var pastAction = info.storedActions.DiscreteActions;
213216
if (!pastAction.IsEmpty())
@@ -238,14 +241,15 @@ public ActionMaskInputGenerator(ITensorAllocator allocator)
238241
m_Allocator = allocator;
239242
}
240243

241-
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
244+
public void Generate(TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
242245
{
243246
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
244247

245248
var maskSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
246249
var agentIndex = 0;
247-
foreach (var infoSensorPair in infos)
250+
for (var infoIndex = 0; infoIndex < infos.Count; infoIndex++)
248251
{
252+
var infoSensorPair = infos[infoIndex];
249253
var agentInfo = infoSensorPair.agentInfo;
250254
var maskList = agentInfo.discreteActionMasks;
251255
for (var j = 0; j < maskSize; j++)
@@ -274,7 +278,7 @@ public RandomNormalInputGenerator(int seed, ITensorAllocator allocator)
274278
m_Allocator = allocator;
275279
}
276280

277-
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
281+
public void Generate(TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
278282
{
279283
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
280284
TensorUtils.FillTensorWithRandomNormal(tensorProxy, m_RandomNormal);
@@ -303,12 +307,13 @@ public void AddSensorIndex(int sensorIndex)
303307
m_SensorIndices.Add(sensorIndex);
304308
}
305309

306-
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
310+
public void Generate(TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos)
307311
{
308312
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
309313
var agentIndex = 0;
310-
foreach (var info in infos)
314+
for (var infoIndex = 0; infoIndex < infos.Count; infoIndex++)
311315
{
316+
var info = infos[infoIndex];
312317
if (info.agentInfo.done)
313318
{
314319
// If the agent is done, we might have a stale reference to the sensors
@@ -320,8 +325,9 @@ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentIn
320325
{
321326
var tensorOffset = 0;
322327
// Write each sensor consecutively to the tensor
323-
foreach (var sensorIndex in m_SensorIndices)
328+
for (var sensorIndexIndex = 0; sensorIndexIndex < m_SensorIndices.Count; sensorIndexIndex++)
324329
{
330+
var sensorIndex = m_SensorIndices[sensorIndexIndex];
325331
var sensor = info.sensors[sensorIndex];
326332
m_ObservationWriter.SetTarget(tensorProxy, agentIndex, tensorOffset);
327333
var numWritten = sensor.Write(m_ObservationWriter);

com.unity.ml-agents/Runtime/Inference/ModelRunner.cs

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,14 @@ internal class ModelRunner
2424
TensorApplier m_TensorApplier;
2525

2626
NNModel m_Model;
27+
string m_ModelName;
2728
InferenceDevice m_InferenceDevice;
2829
IWorker m_Engine;
2930
bool m_Verbose = false;
3031
string[] m_OutputNames;
3132
IReadOnlyList<TensorProxy> m_InferenceInputs;
32-
IReadOnlyList<TensorProxy> m_InferenceOutputs;
33+
List<TensorProxy> m_InferenceOutputs;
34+
Dictionary<string, Tensor> m_InputsByName;
3335
Dictionary<int, List<float>> m_Memories = new Dictionary<int, List<float>>();
3436

3537
SensorShapeValidator m_SensorShapeValidator = new SensorShapeValidator();
@@ -56,6 +58,7 @@ public ModelRunner(
5658
{
5759
Model barracudaModel;
5860
m_Model = model;
61+
m_ModelName = model.name;
5962
m_InferenceDevice = inferenceDevice;
6063
m_TensorAllocator = new TensorCachingAllocator();
6164
if (model != null)
@@ -84,6 +87,8 @@ public ModelRunner(
8487
seed, m_TensorAllocator, m_Memories, barracudaModel);
8588
m_TensorApplier = new TensorApplier(
8689
actionSpec, seed, m_TensorAllocator, m_Memories, barracudaModel);
90+
m_InputsByName = new Dictionary<string, Tensor>();
91+
m_InferenceOutputs = new List<TensorProxy>();
8792
}
8893

8994
public InferenceDevice InferenceDevice
@@ -96,15 +101,14 @@ public NNModel Model
96101
get { return m_Model; }
97102
}
98103

99-
static Dictionary<string, Tensor> PrepareBarracudaInputs(IEnumerable<TensorProxy> infInputs)
104+
void PrepareBarracudaInputs(IReadOnlyList<TensorProxy> infInputs)
100105
{
101-
var inputs = new Dictionary<string, Tensor>();
102-
foreach (var inp in infInputs)
106+
m_InputsByName.Clear();
107+
for (var i = 0; i < infInputs.Count; i++)
103108
{
104-
inputs[inp.name] = inp.data;
109+
var inp = infInputs[i];
110+
m_InputsByName[inp.name] = inp.data;
105111
}
106-
107-
return inputs;
108112
}
109113

110114
public void Dispose()
@@ -114,16 +118,14 @@ public void Dispose()
114118
m_TensorAllocator?.Reset(false);
115119
}
116120

117-
List<TensorProxy> FetchBarracudaOutputs(string[] names)
121+
void FetchBarracudaOutputs(string[] names)
118122
{
119-
var outputs = new List<TensorProxy>();
123+
m_InferenceOutputs.Clear();
120124
foreach (var n in names)
121125
{
122126
var output = m_Engine.PeekOutput(n);
123-
outputs.Add(TensorUtils.TensorProxyFromBarracuda(output, n));
127+
m_InferenceOutputs.Add(TensorUtils.TensorProxyFromBarracuda(output, n));
124128
}
125-
126-
return outputs;
127129
}
128130

129131
public void PutObservations(AgentInfo info, List<ISensor> sensors)
@@ -169,31 +171,33 @@ public void DecideBatch()
169171
}
170172

171173
Profiler.BeginSample("ModelRunner.DecideAction");
174+
Profiler.BeginSample(m_ModelName);
172175

173-
Profiler.BeginSample($"MLAgents.{m_Model.name}.GenerateTensors");
176+
Profiler.BeginSample($"GenerateTensors");
174177
// Prepare the input tensors to be fed into the engine
175178
m_TensorGenerator.GenerateTensors(m_InferenceInputs, currentBatchSize, m_Infos);
176179
Profiler.EndSample();
177180

178-
Profiler.BeginSample($"MLAgents.{m_Model.name}.PrepareBarracudaInputs");
179-
var inputs = PrepareBarracudaInputs(m_InferenceInputs);
181+
Profiler.BeginSample($"PrepareBarracudaInputs");
182+
PrepareBarracudaInputs(m_InferenceInputs);
180183
Profiler.EndSample();
181184

182185
// Execute the Model
183-
Profiler.BeginSample($"MLAgents.{m_Model.name}.ExecuteGraph");
184-
m_Engine.Execute(inputs);
186+
Profiler.BeginSample($"ExecuteGraph");
187+
m_Engine.Execute(m_InputsByName);
185188
Profiler.EndSample();
186189

187-
Profiler.BeginSample($"MLAgents.{m_Model.name}.FetchBarracudaOutputs");
188-
m_InferenceOutputs = FetchBarracudaOutputs(m_OutputNames);
190+
Profiler.BeginSample($"FetchBarracudaOutputs");
191+
FetchBarracudaOutputs(m_OutputNames);
189192
Profiler.EndSample();
190193

191-
Profiler.BeginSample($"MLAgents.{m_Model.name}.ApplyTensors");
194+
Profiler.BeginSample($"ApplyTensors");
192195
// Update the outputs
193196
m_TensorApplier.ApplyTensors(m_InferenceOutputs, m_OrderedAgentsRequestingDecisions, m_LastActionsReceived);
194197
Profiler.EndSample();
195198

196-
Profiler.EndSample();
199+
Profiler.EndSample(); // end name
200+
Profiler.EndSample(); // end ModelRunner.DecideAction
197201

198202
m_Infos.Clear();
199203

com.unity.ml-agents/Runtime/Inference/TensorApplier.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public interface IApplier
3131
/// </param>
3232
/// <param name="actionIds"> List of Agent Ids that will be updated using the tensor's data</param>
3333
/// <param name="lastActions"> Dictionary of AgentId to Actions to be updated</param>
34-
void Apply(TensorProxy tensorProxy, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions);
34+
void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions);
3535
}
3636

3737
readonly Dictionary<string, IApplier> m_Dict = new Dictionary<string, IApplier>();
@@ -90,10 +90,11 @@ public TensorApplier(
9090
/// <exception cref="UnityAgentsException"> One of the tensors does not have an
9191
/// associated applier.</exception>
9292
public void ApplyTensors(
93-
IEnumerable<TensorProxy> tensors, IEnumerable<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
93+
IReadOnlyList<TensorProxy> tensors, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
9494
{
95-
foreach (var tensor in tensors)
95+
for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++)
9696
{
97+
var tensor = tensors[tensorIndex];
9798
if (!m_Dict.ContainsKey(tensor.name))
9899
{
99100
throw new UnityAgentsException(

com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public interface IGenerator
3131
/// the tensor's data.
3232
/// </param>
3333
void Generate(
34-
TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos);
34+
TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos);
3535
}
3636

3737
readonly Dictionary<string, IGenerator> m_Dict = new Dictionary<string, IGenerator>();
@@ -149,10 +149,11 @@ public void InitializeObservations(List<ISensor> sensors, ITensorAllocator alloc
149149
/// <exception cref="UnityAgentsException"> One of the tensors does not have an
150150
/// associated generator.</exception>
151151
public void GenerateTensors(
152-
IEnumerable<TensorProxy> tensors, int currentBatchSize, IEnumerable<AgentInfoSensorsPair> infos)
152+
IReadOnlyList<TensorProxy> tensors, int currentBatchSize, IList<AgentInfoSensorsPair> infos)
153153
{
154-
foreach (var tensor in tensors)
154+
for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++)
155155
{
156+
var tensor = tensors[tensorIndex];
156157
if (!m_Dict.ContainsKey(tensor.name))
157158
{
158159
throw new UnityAgentsException(

0 commit comments

Comments
 (0)