Skip to content

write observations directly to protobuf #3229

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.IO.Abstractions.TestingHelpers;
using System.Reflection;
using MLAgents.CommunicatorObjects;
using MLAgents.Sensor;

namespace MLAgents.Tests
{
Expand Down Expand Up @@ -64,7 +65,7 @@ public void TestStoreInitalize()
storedVectorActions = new[] { 0f, 1f },
};

demoStore.Record(agentInfo, new System.Collections.Generic.List<Sensor.Observation>());
demoStore.Record(agentInfo, new System.Collections.Generic.List<ISensor>());
demoStore.Close();
}

Expand Down
25 changes: 2 additions & 23 deletions UnitySDK/Assets/ML-Agents/Scripts/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,6 @@ public AgentInfo Info
/// </summary>
public VectorSensor collectObservationsSensor;

/// <summary>
/// Internal buffer used for generating float observations.
/// </summary>
float[] m_VectorSensorBuffer;

WriteAdapter m_WriteAdapter = new WriteAdapter();

/// MonoBehaviour function that is called when the attached GameObject
/// becomes enabled or active.
void OnEnable()
Expand Down Expand Up @@ -546,8 +539,6 @@ void SendInfoToBrain()
}
m_Info.actionMasks = m_ActionMasker.GetMask();

// var param = m_PolicyFactory.brainParameters; // look, no brain params!

m_Info.reward = m_Reward;
m_Info.done = m_Done;
m_Info.maxStepReached = m_MaxStepReached;
Expand All @@ -557,19 +548,7 @@ void SendInfoToBrain()

if (m_Recorder != null && m_Recorder.record && Application.isEditor)
{

if (m_VectorSensorBuffer == null)
{
// Create a buffer for writing uncompressed (i.e. float) sensor data to
m_VectorSensorBuffer = new float[sensors.GetSensorFloatObservationSize()];
}

// This is a bit of a hack - if we're in inference mode, observations won't be generated
// But we need these to be generated for the recorder. So generate them here.
var observations = new List<Observation>();
GenerateSensorData(sensors, m_VectorSensorBuffer, m_WriteAdapter, observations);

m_Recorder.WriteExperience(m_Info, observations);
m_Recorder.WriteExperience(m_Info, sensors);
}

}
Expand All @@ -592,7 +571,7 @@ void UpdateSensors()
/// <param name="buffer"> A float array that will be used as buffer when generating the observations. Must
/// be at least the same length as the total number of uncompressed floats in the observations</param>
/// <param name="adapter"> The WriteAdapter that will be used to write the ISensor data to the observations</param>
/// <param name="observations"> A list of observations outputs. This argument will be modified by this method.</param>//
/// <param name="observations"> A list of observations outputs. This argument will be modified by this method.</param>//
public static void GenerateSensorData(List<ISensor> sensors, float[] buffer, WriteAdapter adapter, List<Observation> observations)
Copy link
Contributor Author

@chriselion chriselion Jan 15, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not removed here because the other PR needs it, but this gets rid of the calls to GenerateSensorData

{
int floatsWritten = 0;
Expand Down
4 changes: 2 additions & 2 deletions UnitySDK/Assets/ML-Agents/Scripts/DemonstrationRecorder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ public static string SanitizeName(string demoName, int maxNameLength)
/// <summary>
/// Forwards AgentInfo to Demonstration Store.
/// </summary>
public void WriteExperience(AgentInfo info, List<Observation> observations)
public void WriteExperience(AgentInfo info, List<ISensor> sensors)
{
m_DemoStore.Record(info, observations);
m_DemoStore.Record(info, sensors);
}

public void Close()
Expand Down
12 changes: 9 additions & 3 deletions UnitySDK/Assets/ML-Agents/Scripts/DemonstrationStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public class DemonstrationStore
DemonstrationMetaData m_MetaData;
Stream m_Writer;
float m_CumulativeReward;
WriteAdapter m_WriteAdapter = new WriteAdapter();

public DemonstrationStore(IFileSystem fileSystem)
{
Expand Down Expand Up @@ -92,7 +93,7 @@ void WriteBrainParameters(string brainName, BrainParameters brainParameters)
/// <summary>
/// Write AgentInfo experience to file.
/// </summary>
public void Record(AgentInfo info, List<Observation> observations)
public void Record(AgentInfo info, List<ISensor> sensors)
{
// Increment meta-data counters.
m_MetaData.numberExperiences++;
Expand All @@ -102,8 +103,13 @@ public void Record(AgentInfo info, List<Observation> observations)
EndEpisode();
}

// Write AgentInfo to file.
var agentProto = info.ToInfoActionPairProto(observations);
// Generate observations and add AgentInfo to file.
var agentProto = info.ToInfoActionPairProto();
foreach (var sensor in sensors)
{
agentProto.AgentInfo.Observations.Add(sensor.GetObservationProto(m_WriteAdapter));
}

agentProto.WriteDelimitedTo(m_Writer);
}

Expand Down
58 changes: 47 additions & 11 deletions UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ public static class GrpcExtensions
/// Converts a AgentInfo to a protobuf generated AgentInfoActionPairProto
/// </summary>
/// <returns>The protobuf version of the AgentInfoActionPairProto.</returns>
public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai, List<Observation> observations)
public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai)
{
var agentInfoProto = ai.ToAgentInfoProto(observations);
var agentInfoProto = ai.ToAgentInfoProto();

var agentActionProto = new AgentActionProto
{
Expand All @@ -36,7 +36,7 @@ public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai,
/// Converts a AgentInfo to a protobuf generated AgentInfoProto
/// </summary>
/// <returns>The protobuf version of the AgentInfo.</returns>
public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai, List<Observation> observations)
public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
{
var agentInfoProto = new AgentInfoProto
{
Expand All @@ -51,14 +51,6 @@ public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai, List<Observatio
agentInfoProto.ActionMask.AddRange(ai.actionMasks);
}

if (observations != null)
{
foreach (var obs in observations)
{
agentInfoProto.Observations.Add(obs.ToProto());
}
}

return agentInfoProto;
}

Expand Down Expand Up @@ -197,5 +189,49 @@ public static ObservationProto ToProto(this Observation obs)
obsProto.Shape.AddRange(obs.Shape);
return obsProto;
}

/// <summary>
/// Generate an ObservationProto for the sensor using the provided WriteAdapter.
/// This is equivalent to producing an Observation and calling Observation.ToProto(),
/// but avoids some intermediate memory allocations.
/// </summary>
/// <param name="sensor"></param>
/// <param name="writeAdapter"></param>
/// <returns></returns>
public static ObservationProto GetObservationProto(this ISensor sensor, WriteAdapter writeAdapter)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a docstring on this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

{
var shape = sensor.GetObservationShape();
ObservationProto observationProto = null;
if (sensor.GetCompressionType() == SensorCompressionType.None)
{
var numFloats = sensor.ObservationSize();
var floatDataProto = new ObservationProto.Types.FloatData();
// Resize the float array
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, there doesn't seem to be a way to set the Capacity on this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, there is, but only in newer versions of the library: protocolbuffers/protobuf@da57400

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added code comment on this too.

// TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530
for (var i = 0; i < numFloats; i++)
{
floatDataProto.Data.Add(0.0f);
}

writeAdapter.SetTarget(floatDataProto.Data, sensor.GetObservationShape(), 0);
sensor.Write(writeAdapter);

observationProto = new ObservationProto
{
FloatData = floatDataProto,
CompressionType = (CompressionTypeProto)SensorCompressionType.None,
};
}
else
{
observationProto = new ObservationProto
{
CompressedData = ByteString.CopyFrom(sensor.GetCompressedObservation()),
CompressionType = (CompressionTypeProto)sensor.GetCompressionType(),
};
}
observationProto.Shape.AddRange(shape);
return observationProto;
}
}
}
29 changes: 13 additions & 16 deletions UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ public struct IdCallbackPair

List<string> m_BehaviorNames = new List<string>();
bool m_NeedCommunicateThisStep;
float[] m_VectorObservationBuffer = new float[0];
List<Observation> m_ObservationBuffer = new List<Observation>();
WriteAdapter m_WriteAdapter = new WriteAdapter();
Dictionary<string, SensorShapeValidator> m_SensorShapeValidators = new Dictionary<string, SensorShapeValidator>();
Dictionary<string, List<IdCallbackPair>> m_ActionCallbacks = new Dictionary<string, List<IdCallbackPair>>();
Expand Down Expand Up @@ -239,18 +237,12 @@ public void DecideBatch()
}

/// <summary>
/// Sends the observations of one Agent.
/// Sends the observations of one Agent.
/// </summary>
/// <param name="brainKey">Batch Key.</param>
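/// <param name="info">Agent info.</param>
/// <param name="sensors">Sensors whose observations are written into the agent's protobuf message.</param>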
public void PutObservations(string brainKey, AgentInfo info, List<ISensor> sensors, Action<AgentAction> action)
{
int numFloatObservations = sensors.GetSensorFloatObservationSize();
if (m_VectorObservationBuffer.Length < numFloatObservations)
{
m_VectorObservationBuffer = new float[numFloatObservations];
}

# if DEBUG
if (!m_SensorShapeValidators.ContainsKey(brainKey))
{
Expand All @@ -259,16 +251,21 @@ public void PutObservations(string brainKey, AgentInfo info, List<ISensor> senso
m_SensorShapeValidators[brainKey].ValidateSensors(sensors);
#endif

using (TimerStack.Instance.Scoped("GenerateSensorData"))
{
Agent.GenerateSensorData(sensors, m_VectorObservationBuffer, m_WriteAdapter, m_ObservationBuffer);
}
using (TimerStack.Instance.Scoped("AgentInfo.ToProto"))
{
var agentInfoProto = info.ToAgentInfoProto(m_ObservationBuffer);
var agentInfoProto = info.ToAgentInfoProto();

using (TimerStack.Instance.Scoped("GenerateSensorData"))
{
foreach (var sensor in sensors)
{
var obsProto = sensor.GetObservationProto(m_WriteAdapter);
agentInfoProto.Observations.Add(obsProto);
}
}
m_CurrentUnityRlOutput.AgentInfos[brainKey].Value.Add(agentInfoProto);
}
m_ObservationBuffer.Clear();

m_NeedCommunicateThisStep = true;
if (!m_ActionCallbacks.ContainsKey(brainKey))
{
Expand Down Expand Up @@ -451,7 +448,7 @@ void UpdateSentBrainParameters(UnityRLInitializationOutputProto output)
#region Handling side channels

/// <summary>
/// Registers a side channel to the communicator. The side channel will exchange
/// Registers a side channel to the communicator. The side channel will exchange
/// messages with its Python equivalent.
/// </summary>
/// <param name="sideChannel"> The side channel to be registered.</param>
Expand Down