Skip to content

Commit cc74f81

Browse files
author
Chris Elion
authored
UI for Ray stacks, rename WriteAdapter to ObservationWriter (#3834)
* UI for Ray stacks, rename WriteAdapter to ObservationWriter * move test * changelog and migration
1 parent b73d4dd commit cc74f81

29 files changed

+87
-82
lines changed

Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,18 @@ public abstract class SensorBase : ISensor
2222

2323
/// <summary>
2424
/// Default implementation of Write interface. This creates a temporary array,
25-
/// calls WriteObservation, and then writes the results to the WriteAdapter.
25+
/// calls WriteObservation, and then writes the results to the ObservationWriter.
2626
/// </summary>
27-
/// <param name="adapter"></param>
27+
/// <param name="writer"></param>
2828
/// <returns>The number of elements written.</returns>
29-
public virtual int Write(WriteAdapter adapter)
29+
public virtual int Write(ObservationWriter writer)
3030
{
3131
// TODO reuse buffer for similar agents, don't call GetObservationShape()
3232
var numFloats = this.ObservationSize();
3333
float[] buffer = new float[numFloats];
3434
WriteObservation(buffer);
3535

36-
adapter.AddRange(buffer);
36+
writer.AddRange(buffer);
3737

3838
return numFloats;
3939
}

com.unity.ml-agents/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ and this project adheres to
7878
added in `CollectObservations()`. (#3825)
7979
- Model updates can now happen asynchronously with environment steps for better performance. (#3690)
8080
- `num_updates` and `train_interval` for SAC were replaced with `steps_per_update`. (#3690)
81+
- `WriteAdapter` was renamed to `ObservationWriter`. If you have a custom `ISensor` implementation,
82+
you will need to change the signature of its `Write()` method. (#3834)
8183

8284
### Bug Fixes
8385

com.unity.ml-agents/Editor/RayPerceptionSensorComponentBaseEditor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ protected void OnRayPerceptionInspectorGUI(bool is3d)
3737
// it is not editable during play mode.
3838
EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
3939
{
40-
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true);
40+
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), new GUIContent("Stacked Raycasts"), true);
4141
}
4242
EditorGUI.EndDisabledGroup();
4343

com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -227,14 +227,14 @@ public static ObservationProto ToProto(this Observation obs)
227227
}
228228

229229
/// <summary>
230-
/// Generate an ObservationProto for the sensor using the provided WriteAdapter.
230+
/// Generate an ObservationProto for the sensor using the provided ObservationWriter.
231231
/// This is equivalent to producing an Observation and calling Observation.ToProto(),
232232
/// but avoid some intermediate memory allocations.
233233
/// </summary>
234234
/// <param name="sensor"></param>
235-
/// <param name="writeAdapter"></param>
235+
/// <param name="observationWriter"></param>
236236
/// <returns></returns>
237-
public static ObservationProto GetObservationProto(this ISensor sensor, WriteAdapter writeAdapter)
237+
public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter)
238238
{
239239
var shape = sensor.GetObservationShape();
240240
ObservationProto observationProto = null;
@@ -249,8 +249,8 @@ public static ObservationProto GetObservationProto(this ISensor sensor, WriteAda
249249
floatDataProto.Data.Add(0.0f);
250250
}
251251

252-
writeAdapter.SetTarget(floatDataProto.Data, sensor.GetObservationShape(), 0);
253-
sensor.Write(writeAdapter);
252+
observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationShape(), 0);
253+
sensor.Write(observationWriter);
254254

255255
observationProto = new ObservationProto
256256
{

com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ internal class RpcCommunicator : ICommunicator
2727

2828
List<string> m_BehaviorNames = new List<string>();
2929
bool m_NeedCommunicateThisStep;
30-
WriteAdapter m_WriteAdapter = new WriteAdapter();
30+
ObservationWriter m_ObservationWriter = new ObservationWriter();
3131
Dictionary<string, SensorShapeValidator> m_SensorShapeValidators = new Dictionary<string, SensorShapeValidator>();
3232
Dictionary<string, List<int>> m_OrderedAgentsRequestingDecisions = new Dictionary<string, List<int>>();
3333

@@ -322,7 +322,7 @@ public void PutObservations(string behaviorName, AgentInfo info, List<ISensor> s
322322
{
323323
foreach (var sensor in sensors)
324324
{
325-
var obsProto = sensor.GetObservationProto(m_WriteAdapter);
325+
var obsProto = sensor.GetObservationProto(m_ObservationWriter);
326326
agentInfoProto.Observations.Add(obsProto);
327327
}
328328
}

com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public class DemonstrationWriter
2020
DemonstrationMetaData m_MetaData;
2121
Stream m_Writer;
2222
float m_CumulativeReward;
23-
WriteAdapter m_WriteAdapter = new WriteAdapter();
23+
ObservationWriter m_ObservationWriter = new ObservationWriter();
2424

2525
/// <summary>
2626
/// Create a DemonstrationWriter that will write to the specified stream.
@@ -117,7 +117,7 @@ internal void Record(AgentInfo info, List<ISensor> sensors)
117117
var agentProto = info.ToInfoActionPairProto();
118118
foreach (var sensor in sensors)
119119
{
120-
agentProto.AgentInfo.Observations.Add(sensor.GetObservationProto(m_WriteAdapter));
120+
agentProto.AgentInfo.Observations.Add(sensor.GetObservationProto(m_ObservationWriter));
121121
}
122122

123123
agentProto.WriteDelimitedTo(m_Writer);

com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ internal class VectorObservationGenerator : TensorGenerator.IGenerator
8282
{
8383
readonly ITensorAllocator m_Allocator;
8484
List<int> m_SensorIndices = new List<int>();
85-
WriteAdapter m_WriteAdapter = new WriteAdapter();
85+
ObservationWriter m_ObservationWriter = new ObservationWriter();
8686

8787
public VectorObservationGenerator(ITensorAllocator allocator)
8888
{
@@ -115,8 +115,8 @@ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentIn
115115
foreach (var sensorIndex in m_SensorIndices)
116116
{
117117
var sensor = info.sensors[sensorIndex];
118-
m_WriteAdapter.SetTarget(tensorProxy, agentIndex, tensorOffset);
119-
var numWritten = sensor.Write(m_WriteAdapter);
118+
m_ObservationWriter.SetTarget(tensorProxy, agentIndex, tensorOffset);
119+
var numWritten = sensor.Write(m_ObservationWriter);
120120
tensorOffset += numWritten;
121121
}
122122
Debug.AssertFormat(
@@ -350,7 +350,7 @@ internal class VisualObservationInputGenerator : TensorGenerator.IGenerator
350350
{
351351
readonly int m_SensorIndex;
352352
readonly ITensorAllocator m_Allocator;
353-
WriteAdapter m_WriteAdapter = new WriteAdapter();
353+
ObservationWriter m_ObservationWriter = new ObservationWriter();
354354

355355
public VisualObservationInputGenerator(
356356
int sensorIndex, ITensorAllocator allocator)
@@ -375,8 +375,8 @@ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentIn
375375
}
376376
else
377377
{
378-
m_WriteAdapter.SetTarget(tensorProxy, agentIndex, 0);
379-
sensor.Write(m_WriteAdapter);
378+
m_ObservationWriter.SetTarget(tensorProxy, agentIndex, 0);
379+
sensor.Write(m_ObservationWriter);
380380
}
381381
agentIndex++;
382382
}

com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ internal class HeuristicPolicy : IPolicy
1919
bool m_Done;
2020
bool m_DecisionRequested;
2121

22-
WriteAdapter m_WriteAdapter = new WriteAdapter();
22+
ObservationWriter m_ObservationWriter = new ObservationWriter();
2323
NullList m_NullList = new NullList();
2424

2525

@@ -128,8 +128,8 @@ void StepSensors(List<ISensor> sensors)
128128
{
129129
if (sensor.GetCompressionType() == SensorCompressionType.None)
130130
{
131-
m_WriteAdapter.SetTarget(m_NullList, sensor.GetObservationShape(), 0);
132-
sensor.Write(m_WriteAdapter);
131+
m_ObservationWriter.SetTarget(m_NullList, sensor.GetObservationShape(), 0);
132+
sensor.Write(m_ObservationWriter);
133133
}
134134
else
135135
{

com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,16 @@ public byte[] GetCompressedObservation()
9191
}
9292

9393
/// <summary>
94-
/// Writes out the generated, uncompressed image to the provided <see cref="WriteAdapter"/>.
94+
/// Writes out the generated, uncompressed image to the provided <see cref="ObservationWriter"/>.
9595
/// </summary>
96-
/// <param name="adapter">Where the observation is written to.</param>
96+
/// <param name="writer">Where the observation is written to.</param>
9797
/// <returns></returns>
98-
public int Write(WriteAdapter adapter)
98+
public int Write(ObservationWriter writer)
9999
{
100100
using (TimerStack.Instance.Scoped("CameraSensor.WriteToTensor"))
101101
{
102102
var texture = ObservationToTexture(m_Camera, m_Width, m_Height);
103-
var numWritten = Utilities.TextureToTensorProxy(texture, adapter, m_Grayscale);
103+
var numWritten = Utilities.TextureToTensorProxy(texture, writer, m_Grayscale);
104104
DestroyTexture(texture);
105105
return numWritten;
106106
}

com.unity.ml-agents/Runtime/Sensors/ISensor.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ public interface ISensor
3131
int[] GetObservationShape();
3232

3333
/// <summary>
34-
/// Write the observation data directly to the <see cref="WriteAdapter"/>.
34+
/// Write the observation data directly to the <see cref="ObservationWriter"/>.
3535
/// This is considered an advanced interface; for a simpler approach, use
3636
/// <see cref="SensorBase"/> and override <see cref="SensorBase.WriteObservation"/> instead.
3737
/// Note that this (and <see cref="GetCompressedObservation"/>) may
3838
/// be called multiple times per agent step, so should not mutate any internal state.
3939
/// </summary>
40-
/// <param name="adapter">Where the observations will be written to.</param>
40+
/// <param name="writer">Where the observations will be written to.</param>
4141
/// <returns>The number of elements written.</returns>
42-
int Write(WriteAdapter adapter);
42+
int Write(ObservationWriter writer);
4343

4444
/// <summary>
4545
/// Return a compressed representation of the observation. For small observations,

com.unity.ml-agents/Runtime/Sensors/WriteAdapter.cs renamed to com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace MLAgents.Sensors
88
/// <summary>
99
/// Allows sensors to write to both TensorProxy and float arrays/lists.
1010
/// </summary>
11-
public class WriteAdapter
11+
public class ObservationWriter
1212
{
1313
IList<float> m_Data;
1414
int m_Offset;
@@ -18,10 +18,10 @@ public class WriteAdapter
1818

1919
TensorShape m_TensorShape;
2020

21-
internal WriteAdapter() { }
21+
internal ObservationWriter() { }
2222

2323
/// <summary>
24-
/// Set the adapter to write to an IList at the given channelOffset.
24+
/// Set the writer to write to an IList at the given channelOffset.
2525
/// </summary>
2626
/// <param name="data">Float array or list that will be written to.</param>
2727
/// <param name="shape">Shape of the observations to be written.</param>
@@ -44,7 +44,7 @@ internal void SetTarget(IList<float> data, int[] shape, int offset)
4444
}
4545

4646
/// <summary>
47-
/// Set the adapter to write to a TensorProxy at the given batch and channel offset.
47+
/// Set the writer to write to a TensorProxy at the given batch and channel offset.
4848
/// </summary>
4949
/// <param name="tensorProxy">Tensor proxy that will be written to.</param>
5050
/// <param name="batchIndex">Batch index in the tensor proxy (i.e. the index of the Agent).</param>

com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,11 @@ internal void SetRayPerceptionInput(RayPerceptionInput rayInput)
286286

287287
/// <summary>
288288
/// Computes the ray perception observations and saves them to the provided
289-
/// <see cref="WriteAdapter"/>.
289+
/// <see cref="ObservationWriter"/>.
290290
/// </summary>
291-
/// <param name="adapter">Where the ray perception observations are written to.</param>
291+
/// <param name="writer">Where the ray perception observations are written to.</param>
292292
/// <returns></returns>
293-
public int Write(WriteAdapter adapter)
293+
public int Write(ObservationWriter writer)
294294
{
295295
using (TimerStack.Instance.Scoped("RayPerceptionSensor.Perceive"))
296296
{
@@ -322,8 +322,8 @@ public int Write(WriteAdapter adapter)
322322

323323
rayOutput.ToFloatArray(numDetectableTags, rayIndex, m_Observations);
324324
}
325-
// Finally, add the observations to the WriteAdapter
326-
adapter.AddRange(m_Observations);
325+
// Finally, add the observations to the ObservationWriter
326+
writer.AddRange(m_Observations);
327327
}
328328
return m_Observations.Length;
329329
}

com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ public LayerMask RayLayerMask
112112

113113
[HideInInspector, SerializeField, FormerlySerializedAs("observationStacks")]
114114
[Range(1, 50)]
115-
[Tooltip("Whether to stack previous observations. Using 1 means no previous observations.")]
115+
[Tooltip("Number of raycast results that will be stacked before being fed to the neural network.")]
116116
int m_ObservationStacks = 1;
117117

118118
/// <summary>

com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,12 @@ public byte[] GetCompressedObservation()
6868
}
6969

7070
/// <inheritdoc/>
71-
public int Write(WriteAdapter adapter)
71+
public int Write(ObservationWriter writer)
7272
{
7373
using (TimerStack.Instance.Scoped("RenderTextureSensor.Write"))
7474
{
7575
var texture = ObservationToTexture(m_RenderTexture);
76-
var numWritten = Utilities.TextureToTensorProxy(texture, adapter, m_Grayscale);
76+
var numWritten = Utilities.TextureToTensorProxy(texture, writer, m_Grayscale);
7777
DestroyTexture(texture);
7878
return numWritten;
7979
}

com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public class StackingSensor : ISensor
3333
float[][] m_StackedObservations;
3434

3535
int m_CurrentIndex;
36-
WriteAdapter m_LocalAdapter = new WriteAdapter();
36+
ObservationWriter m_LocalWriter = new ObservationWriter();
3737

3838
/// <summary>
3939
/// Initializes the sensor.
@@ -76,19 +76,19 @@ public StackingSensor(ISensor wrapped, int numStackedObservations)
7676
}
7777

7878
/// <inheritdoc/>
79-
public int Write(WriteAdapter adapter)
79+
public int Write(ObservationWriter writer)
8080
{
81-
// First, call the wrapped sensor's write method. Make sure to use our own adapter, not the passed one.
81+
// First, call the wrapped sensor's write method. Make sure to use our own writer, not the passed one.
8282
var wrappedShape = m_WrappedSensor.GetObservationShape();
83-
m_LocalAdapter.SetTarget(m_StackedObservations[m_CurrentIndex], wrappedShape, 0);
84-
m_WrappedSensor.Write(m_LocalAdapter);
83+
m_LocalWriter.SetTarget(m_StackedObservations[m_CurrentIndex], wrappedShape, 0);
84+
m_WrappedSensor.Write(m_LocalWriter);
8585

8686
// Now write the saved observations (oldest first)
8787
var numWritten = 0;
8888
for (var i = 0; i < m_NumStackedObservations; i++)
8989
{
9090
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
91-
adapter.AddRange(m_StackedObservations[obsIndex], numWritten);
91+
writer.AddRange(m_StackedObservations[obsIndex], numWritten);
9292
numWritten += m_UnstackedObservationSize;
9393
}
9494

com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public VectorSensor(int observationSize, string name = null)
3333
}
3434

3535
/// <inheritdoc/>
36-
public int Write(WriteAdapter adapter)
36+
public int Write(ObservationWriter writer)
3737
{
3838
var expectedObservations = m_Shape[0];
3939
if (m_Observations.Count > expectedObservations)
@@ -57,7 +57,7 @@ public int Write(WriteAdapter adapter)
5757
m_Observations.Add(0);
5858
}
5959
}
60-
adapter.AddRange(m_Observations);
60+
writer.AddRange(m_Observations);
6161
return expectedObservations;
6262
}
6363

com.unity.ml-agents/Runtime/Utilities.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ namespace MLAgents
77
internal static class Utilities
88
{
99
/// <summary>
10-
/// Puts a Texture2D into a WriteAdapter.
10+
/// Puts a Texture2D into a ObservationWriter.
1111
/// </summary>
1212
/// <param name="texture">
1313
/// The texture to be put into the tensor.
1414
/// </param>
15-
/// <param name="adapter">
16-
/// Adapter to fill with Texture data.
15+
/// <param name="obsWriter">
16+
/// Writer to fill with Texture data.
1717
/// </param>
1818
/// <param name="grayScale">
1919
/// If set to <c>true</c> the textures will be converted to grayscale before
@@ -22,7 +22,7 @@ internal static class Utilities
2222
/// <returns>The number of floats written</returns>
2323
internal static int TextureToTensorProxy(
2424
Texture2D texture,
25-
WriteAdapter adapter,
25+
ObservationWriter obsWriter,
2626
bool grayScale)
2727
{
2828
var width = texture.width;
@@ -38,15 +38,15 @@ internal static int TextureToTensorProxy(
3838
var currentPixel = texturePixels[(height - h - 1) * width + w];
3939
if (grayScale)
4040
{
41-
adapter[h, w, 0] =
41+
obsWriter[h, w, 0] =
4242
(currentPixel.r + currentPixel.g + currentPixel.b) / 3f / 255.0f;
4343
}
4444
else
4545
{
4646
// For Color32, the r, g and b values are between 0 and 255.
47-
adapter[h, w, 0] = currentPixel.r / 255.0f;
48-
adapter[h, w, 1] = currentPixel.g / 255.0f;
49-
adapter[h, w, 2] = currentPixel.b / 255.0f;
47+
obsWriter[h, w, 0] = currentPixel.r / 255.0f;
48+
obsWriter[h, w, 1] = currentPixel.g / 255.0f;
49+
obsWriter[h, w, 2] = currentPixel.b / 255.0f;
5050
}
5151
}
5252
}

0 commit comments

Comments
 (0)