Skip to content

Commit bf0919d

Browse files
committed
Added testing and returning staked observation flat.
1 parent f8b17c7 commit bf0919d

File tree

3 files changed

+25
-8
lines changed

3 files changed

+25
-8
lines changed

com.unity.ml-agents/Runtime/Agent.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1190,7 +1190,7 @@ public ReadOnlyCollection<float> GetObservations()
11901190
/// <see cref="Heuristic(in ActionBuffers)"/> method to avoid recomputing the observations.
11911191
/// </summary>
11921192
/// <returns>A read-only view of the stacked observations list.</returns>
1193-
public ReadOnlyCollection<ReadOnlyCollection<float>> GetStackedObservations()
1193+
public ReadOnlyCollection<float> GetStackedObservations()
11941194
{
11951195
return stackedCollectObservationsSensor.GetStackedObservations();
11961196
}

com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ public class StackingSensor : ISensor, IBuiltInSensor
3636
/// Buffer of previous observations
3737
/// </summary>
3838
float[][] m_StackedObservations;
39+
//[
40+
//[1,2]
41+
//[3,4]
42+
//[5,6]
43+
//]
3944

4045
byte[][] m_StackedCompressedObservations;
4146

@@ -286,15 +291,15 @@ public BuiltInSensorType GetBuiltInSensorType()
286291
/// Returns a read-only view of the observations that added.
287292
/// </summary>
288293
/// <returns>A read-only view of the observations list.</returns>
289-
internal ReadOnlyCollection<ReadOnlyCollection<float>> GetStackedObservations()
294+
internal ReadOnlyCollection<float> GetStackedObservations()
290295
{
291-
List<ReadOnlyCollection<float>> layer = new List<ReadOnlyCollection<float>>();
292-
foreach (float[] l in m_StackedObservations)
296+
List<float> observations = new List<float>();
297+
for (var i = 0; i < m_NumStackedObservations; i++)
293298
{
294-
layer.Add(l.ToList().AsReadOnly());
299+
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
300+
observations.AddRange(m_StackedObservations[obsIndex].ToList());
295301
}
296-
297-
return layer.AsReadOnly();
302+
return observations.AsReadOnly();
298303
}
299304
}
300305
}

com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,29 +76,41 @@ public void AssertStackingReset()
7676
public void TestVectorStacking()
7777
{
7878
VectorSensor wrapped = new VectorSensor(2);
79-
ISensor sensor = new StackingSensor(wrapped, 3);
79+
StackingSensor sensor = new StackingSensor(wrapped, 3);
8080

8181
wrapped.AddObservation(new[] { 1f, 2f });
8282
SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 1f, 2f });
83+
var data = sensor.GetStackedObservations();
84+
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 0f, 0f, 0f, 0f, 1f, 2f }));
8385

8486
sensor.Update();
8587
wrapped.AddObservation(new[] { 3f, 4f });
8688
SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 1f, 2f, 3f, 4f });
89+
data = sensor.GetStackedObservations();
90+
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 0f, 0f, 1f, 2f, 3f, 4f }));
8791

8892
sensor.Update();
8993
wrapped.AddObservation(new[] { 5f, 6f });
9094
SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f, 5f, 6f });
95+
data = sensor.GetStackedObservations();
96+
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 1f, 2f, 3f, 4f, 5f, 6f }));
9197

9298
sensor.Update();
9399
wrapped.AddObservation(new[] { 7f, 8f });
94100
SensorTestHelper.CompareObservation(sensor, new[] { 3f, 4f, 5f, 6f, 7f, 8f });
101+
data = sensor.GetStackedObservations();
102+
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 3f, 4f, 5f, 6f, 7f, 8f }));
95103

96104
sensor.Update();
97105
wrapped.AddObservation(new[] { 9f, 10f });
98106
SensorTestHelper.CompareObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f });
107+
data = sensor.GetStackedObservations();
108+
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 5f, 6f, 7f, 8f, 9f, 10f }));
99109

100110
// Check that if we don't call Update(), the same observations are produced
101111
SensorTestHelper.CompareObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f });
112+
data = sensor.GetStackedObservations();
113+
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 5f, 6f, 7f, 8f, 9f, 10f }));
102114
}
103115

104116
[Test]

0 commit comments

Comments
 (0)