Skip to content

Commit f8b17c7

Browse files
committed
Method to return stacked observations
1 parent 05c0275 commit f8b17c7

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

com.unity.ml-agents/Runtime/Agent.cs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,11 @@ internal struct AgentParameters
320320
/// </summary>
321321
internal VectorSensor collectObservationsSensor;
322322

323+
/// <summary>
324+
/// StackingSensor which is written to by AddVectorObs
325+
/// </summary>
326+
internal StackingSensor stackedCollectObservationsSensor;
327+
323328
private RecursionChecker m_CollectObservationsChecker = new RecursionChecker("CollectObservations");
324329
private RecursionChecker m_OnEpisodeBeginChecker = new RecursionChecker("OnEpisodeBegin");
325330

@@ -981,9 +986,9 @@ internal void InitializeSensors()
981986
collectObservationsSensor = new VectorSensor(param.VectorObservationSize);
982987
if (param.NumStackedVectorObservations > 1)
983988
{
984-
var stackingSensor = new StackingSensor(
989+
stackedCollectObservationsSensor = new StackingSensor(
985990
collectObservationsSensor, param.NumStackedVectorObservations);
986-
sensors.Add(stackingSensor);
991+
sensors.Add(stackedCollectObservationsSensor);
987992
}
988993
else
989994
{
@@ -1179,6 +1184,17 @@ public ReadOnlyCollection<float> GetObservations()
11791184
return collectObservationsSensor.GetObservations();
11801185
}
11811186

1187+
/// <summary>
1188+
/// Returns a read-only view of the stacked observations that were generated in
1189+
/// <see cref="CollectObservations(VectorSensor)"/>. This is mainly useful inside of a
1190+
/// <see cref="Heuristic(in ActionBuffers)"/> method to avoid recomputing the observations.
1191+
/// </summary>
1192+
/// <returns>A read-only view of the stacked observations list.</returns>
1193+
public ReadOnlyCollection<ReadOnlyCollection<float>> GetStackedObservations()
1194+
{
1195+
return stackedCollectObservationsSensor.GetStackedObservations();
1196+
}
1197+
11821198
/// <summary>
11831199
/// Implement `WriteDiscreteActionMask()` to collects the masks for discrete
11841200
/// actions. When using discrete actions, the agent will not perform the masked

com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using System;
2+
using System.Collections.Generic;
3+
using System.Collections.ObjectModel;
24
using System.Linq;
35
using UnityEngine;
46
using Unity.Barracuda;
@@ -279,5 +281,20 @@ public BuiltInSensorType GetBuiltInSensorType()
279281
IBuiltInSensor wrappedBuiltInSensor = m_WrappedSensor as IBuiltInSensor;
280282
return wrappedBuiltInSensor?.GetBuiltInSensorType() ?? BuiltInSensorType.Unknown;
281283
}
284+
285+
/// <summary>
286+
/// Returns a read-only view of the observations that added.
287+
/// </summary>
288+
/// <returns>A read-only view of the observations list.</returns>
289+
internal ReadOnlyCollection<ReadOnlyCollection<float>> GetStackedObservations()
290+
{
291+
List<ReadOnlyCollection<float>> layer = new List<ReadOnlyCollection<float>>();
292+
foreach (float[] l in m_StackedObservations)
293+
{
294+
layer.Add(l.ToList().AsReadOnly());
295+
}
296+
297+
return layer.AsReadOnly();
298+
}
282299
}
283300
}

0 commit comments

Comments
 (0)