Skip to content

Commit c049976

Browse files
author
Chris Elion
authored
write observations directly to protobuf (#3229)
* write observations directly to protobuf * docstring and comment about Capacity
1 parent 9e7eea5 commit c049976

File tree

6 files changed

+74
-55
lines changed

6 files changed

+74
-55
lines changed

UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.IO.Abstractions.TestingHelpers;
44
using System.Reflection;
55
using MLAgents.CommunicatorObjects;
6+
using MLAgents.Sensor;
67

78
namespace MLAgents.Tests
89
{
@@ -64,7 +65,7 @@ public void TestStoreInitalize()
6465
storedVectorActions = new[] { 0f, 1f },
6566
};
6667

67-
demoStore.Record(agentInfo, new System.Collections.Generic.List<Sensor.Observation>());
68+
demoStore.Record(agentInfo, new System.Collections.Generic.List<ISensor>());
6869
demoStore.Close();
6970
}
7071

UnitySDK/Assets/ML-Agents/Scripts/Agent.cs

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -235,13 +235,6 @@ public AgentInfo Info
235235
/// </summary>
236236
public VectorSensor collectObservationsSensor;
237237

238-
/// <summary>
239-
/// Internal buffer used for generating float observations.
240-
/// </summary>
241-
float[] m_VectorSensorBuffer;
242-
243-
WriteAdapter m_WriteAdapter = new WriteAdapter();
244-
245238
/// MonoBehaviour function that is called when the attached GameObject
246239
/// becomes enabled or active.
247240
void OnEnable()
@@ -558,8 +551,6 @@ void SendInfoToBrain()
558551
}
559552
m_Info.actionMasks = m_ActionMasker.GetMask();
560553

561-
// var param = m_PolicyFactory.brainParameters; // look, no brain params!
562-
563554
m_Info.reward = m_Reward;
564555
m_Info.done = m_Done;
565556
m_Info.maxStepReached = m_MaxStepReached;
@@ -569,19 +560,7 @@ void SendInfoToBrain()
569560

570561
if (m_Recorder != null && m_Recorder.record && Application.isEditor)
571562
{
572-
573-
if (m_VectorSensorBuffer == null)
574-
{
575-
// Create a buffer for writing uncompressed (i.e. float) sensor data to
576-
m_VectorSensorBuffer = new float[sensors.GetSensorFloatObservationSize()];
577-
}
578-
579-
// This is a bit of a hack - if we're in inference mode, observations won't be generated
580-
// But we need these to be generated for the recorder. So generate them here.
581-
var observations = new List<Observation>();
582-
GenerateSensorData(sensors, m_VectorSensorBuffer, m_WriteAdapter, observations);
583-
584-
m_Recorder.WriteExperience(m_Info, observations);
563+
m_Recorder.WriteExperience(m_Info, sensors);
585564
}
586565

587566
}

UnitySDK/Assets/ML-Agents/Scripts/DemonstrationRecorder.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ public static string SanitizeName(string demoName, int maxNameLength)
7070
/// <summary>
7171
/// Forwards AgentInfo to Demonstration Store.
7272
/// </summary>
73-
public void WriteExperience(AgentInfo info, List<Observation> observations)
73+
public void WriteExperience(AgentInfo info, List<ISensor> sensors)
7474
{
75-
m_DemoStore.Record(info, observations);
75+
m_DemoStore.Record(info, sensors);
7676
}
7777

7878
public void Close()

UnitySDK/Assets/ML-Agents/Scripts/DemonstrationStore.cs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ public class DemonstrationStore
2020
DemonstrationMetaData m_MetaData;
2121
Stream m_Writer;
2222
float m_CumulativeReward;
23+
WriteAdapter m_WriteAdapter = new WriteAdapter();
2324

2425
public DemonstrationStore(IFileSystem fileSystem)
2526
{
@@ -92,7 +93,7 @@ void WriteBrainParameters(string brainName, BrainParameters brainParameters)
9293
/// <summary>
9394
/// Write AgentInfo experience to file.
9495
/// </summary>
95-
public void Record(AgentInfo info, List<Observation> observations)
96+
public void Record(AgentInfo info, List<ISensor> sensors)
9697
{
9798
// Increment meta-data counters.
9899
m_MetaData.numberExperiences++;
@@ -102,8 +103,13 @@ public void Record(AgentInfo info, List<Observation> observations)
102103
EndEpisode();
103104
}
104105

105-
// Write AgentInfo to file.
106-
var agentProto = info.ToInfoActionPairProto(observations);
106+
// Generate observations and add AgentInfo to file.
107+
var agentProto = info.ToInfoActionPairProto();
108+
foreach (var sensor in sensors)
109+
{
110+
agentProto.AgentInfo.Observations.Add(sensor.GetObservationProto(m_WriteAdapter));
111+
}
112+
107113
agentProto.WriteDelimitedTo(m_Writer);
108114
}
109115

UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ public static class GrpcExtensions
1616
/// Converts a AgentInfo to a protobuf generated AgentInfoActionPairProto
1717
/// </summary>
1818
/// <returns>The protobuf version of the AgentInfoActionPairProto.</returns>
19-
public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai, List<Observation> observations)
19+
public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai)
2020
{
21-
var agentInfoProto = ai.ToAgentInfoProto(observations);
21+
var agentInfoProto = ai.ToAgentInfoProto();
2222

2323
var agentActionProto = new AgentActionProto
2424
{
@@ -36,7 +36,7 @@ public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai,
3636
/// Converts a AgentInfo to a protobuf generated AgentInfoProto
3737
/// </summary>
3838
/// <returns>The protobuf version of the AgentInfo.</returns>
39-
public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai, List<Observation> observations)
39+
public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
4040
{
4141
var agentInfoProto = new AgentInfoProto
4242
{
@@ -51,14 +51,6 @@ public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai, List<Observatio
5151
agentInfoProto.ActionMask.AddRange(ai.actionMasks);
5252
}
5353

54-
if (observations != null)
55-
{
56-
foreach (var obs in observations)
57-
{
58-
agentInfoProto.Observations.Add(obs.ToProto());
59-
}
60-
}
61-
6254
return agentInfoProto;
6355
}
6456

@@ -197,5 +189,49 @@ public static ObservationProto ToProto(this Observation obs)
197189
obsProto.Shape.AddRange(obs.Shape);
198190
return obsProto;
199191
}
192+
193+
/// <summary>
194+
/// Generate an ObservationProto for the sensor using the provided WriteAdapter.
195+
/// This is equivalent to producing an Observation and calling Observation.ToProto(),
196+
/// but avoid some intermediate memory allocations.
197+
/// </summary>
198+
/// <param name="sensor"></param>
199+
/// <param name="writeAdapter"></param>
200+
/// <returns></returns>
201+
public static ObservationProto GetObservationProto(this ISensor sensor, WriteAdapter writeAdapter)
202+
{
203+
var shape = sensor.GetObservationShape();
204+
ObservationProto observationProto = null;
205+
if (sensor.GetCompressionType() == SensorCompressionType.None)
206+
{
207+
var numFloats = sensor.ObservationSize();
208+
var floatDataProto = new ObservationProto.Types.FloatData();
209+
// Resize the float array
210+
// TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530
211+
for (var i = 0; i < numFloats; i++)
212+
{
213+
floatDataProto.Data.Add(0.0f);
214+
}
215+
216+
writeAdapter.SetTarget(floatDataProto.Data, sensor.GetObservationShape(), 0);
217+
sensor.Write(writeAdapter);
218+
219+
observationProto = new ObservationProto
220+
{
221+
FloatData = floatDataProto,
222+
CompressionType = (CompressionTypeProto)SensorCompressionType.None,
223+
};
224+
}
225+
else
226+
{
227+
observationProto = new ObservationProto
228+
{
229+
CompressedData = ByteString.CopyFrom(sensor.GetCompressedObservation()),
230+
CompressionType = (CompressionTypeProto)sensor.GetCompressionType(),
231+
};
232+
}
233+
observationProto.Shape.AddRange(shape);
234+
return observationProto;
235+
}
200236
}
201237
}

UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ public struct IdCallbackPair
3535

3636
List<string> m_BehaviorNames = new List<string>();
3737
bool m_NeedCommunicateThisStep;
38-
float[] m_VectorObservationBuffer = new float[0];
39-
List<Observation> m_ObservationBuffer = new List<Observation>();
4038
WriteAdapter m_WriteAdapter = new WriteAdapter();
4139
Dictionary<string, SensorShapeValidator> m_SensorShapeValidators = new Dictionary<string, SensorShapeValidator>();
4240
Dictionary<string, List<IdCallbackPair>> m_ActionCallbacks = new Dictionary<string, List<IdCallbackPair>>();
@@ -239,18 +237,12 @@ public void DecideBatch()
239237
}
240238

241239
/// <summary>
242-
/// Sends the observations of one Agent.
240+
/// Sends the observations of one Agent.
243241
/// </summary>
244242
/// <param name="brainKey">Batch Key.</param>
245243
/// <param name="agent">Agent info.</param>
246244
public void PutObservations(string brainKey, AgentInfo info, List<ISensor> sensors, Action<AgentAction> action)
247245
{
248-
int numFloatObservations = sensors.GetSensorFloatObservationSize();
249-
if (m_VectorObservationBuffer.Length < numFloatObservations)
250-
{
251-
m_VectorObservationBuffer = new float[numFloatObservations];
252-
}
253-
254246
# if DEBUG
255247
if (!m_SensorShapeValidators.ContainsKey(brainKey))
256248
{
@@ -259,16 +251,21 @@ public void PutObservations(string brainKey, AgentInfo info, List<ISensor> senso
259251
m_SensorShapeValidators[brainKey].ValidateSensors(sensors);
260252
#endif
261253

262-
using (TimerStack.Instance.Scoped("GenerateSensorData"))
263-
{
264-
Agent.GenerateSensorData(sensors, m_VectorObservationBuffer, m_WriteAdapter, m_ObservationBuffer);
265-
}
266254
using (TimerStack.Instance.Scoped("AgentInfo.ToProto"))
267255
{
268-
var agentInfoProto = info.ToAgentInfoProto(m_ObservationBuffer);
256+
var agentInfoProto = info.ToAgentInfoProto();
257+
258+
using (TimerStack.Instance.Scoped("GenerateSensorData"))
259+
{
260+
foreach (var sensor in sensors)
261+
{
262+
var obsProto = sensor.GetObservationProto(m_WriteAdapter);
263+
agentInfoProto.Observations.Add(obsProto);
264+
}
265+
}
269266
m_CurrentUnityRlOutput.AgentInfos[brainKey].Value.Add(agentInfoProto);
270267
}
271-
m_ObservationBuffer.Clear();
268+
272269
m_NeedCommunicateThisStep = true;
273270
if (!m_ActionCallbacks.ContainsKey(brainKey))
274271
{
@@ -451,7 +448,7 @@ void UpdateSentBrainParameters(UnityRLInitializationOutputProto output)
451448
#region Handling side channels
452449

453450
/// <summary>
454-
/// Registers a side channel to the communicator. The side channel will exchange
451+
/// Registers a side channel to the communicator. The side channel will exchange
455452
/// messages with its Python equivalent.
456453
/// </summary>
457454
/// <param name="sideChannel"> The side channel to be registered.</param>

0 commit comments

Comments
 (0)