Skip to content

Commit 8142aaf

Browse files
author
Chris Elion
authored
access to observations in Heuristic (#3825)
* access to observations in Heuristic * changelog
1 parent e88ce26 commit 8142aaf

File tree

4 files changed

+31
-2
lines changed

4 files changed

+31
-2
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,10 @@ and this project adheres to
7070
- Academy.InferenceSeed property was added. This is used to initialize the
7171
random number generator in ModelRunner, and is incremented for each ModelRunner. (#3823)
7272
- Updated Barracuda to 0.6.3-preview.
73-
- Model updates can now happen asynchronously with environment steps for better performance. (#3690)
74-
- `num_updates` and `train_interval` for SAC were replaced with `steps_per_update`. (#3690)
73+
- Added `Agent.GetObservations(), which returns a read-only view of the observations
74+
added in `CollectObservations()`. (#3825)
75+
- Model updates can now happen asynchronously with environment steps for better performance. (#3690)
76+
- `num_updates` and `train_interval` for SAC were replaced with `steps_per_update`. (#3690)
7577

7678
### Bug Fixes
7779

com.unity.ml-agents/Runtime/Agent.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Collections.ObjectModel;
34
using UnityEngine;
45
using Barracuda;
56
using MLAgents.Sensors;
@@ -691,6 +692,17 @@ public virtual void CollectObservations(VectorSensor sensor)
691692
{
692693
}
693694

695+
/// <summary>
696+
/// Returns a read-only view of the observations that were generated in
697+
/// <see cref="CollectObservations(VectorSensor)"/>. This is mainly useful inside of a
698+
/// <see cref="Heuristic(float[])"/> method to avoid recomputing the observations.
699+
/// </summary>
700+
/// <returns>A read-only view of the observations list.</returns>
701+
public ReadOnlyCollection<float> GetObservations()
702+
{
703+
return collectObservationsSensor.GetObservations();
704+
}
705+
694706
/// <summary>
695707
/// Collects the masks for discrete actions.
696708
/// When using discrete actions, the agent will not perform the masked action.

com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System.Collections.Generic;
2+
using System.Collections.ObjectModel;
23
using UnityEngine;
34

45
namespace MLAgents.Sensors
@@ -60,6 +61,15 @@ public int Write(WriteAdapter adapter)
6061
return expectedObservations;
6162
}
6263

64+
/// <summary>
65+
/// Returns a read-only view of the observations that added.
66+
/// </summary>
67+
/// <returns>A read-only view of the observations list.</returns>
68+
internal ReadOnlyCollection<float> GetObservations()
69+
{
70+
return m_Observations.AsReadOnly();
71+
}
72+
6373
/// <inheritdoc/>
6474
public void Update()
6575
{

com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ public override void OnEpisodeBegin()
9696

9797
public override void Heuristic(float[] actionsOut)
9898
{
99+
var obs = GetObservations();
100+
actionsOut[0] = obs[0];
99101
heuristicCalls++;
100102
}
101103
}
@@ -667,6 +669,9 @@ public void TestHeuristicPolicyStepsSensors()
667669
Assert.AreEqual(numSteps, agent1.heuristicCalls);
668670
Assert.AreEqual(numSteps, agent1.sensor1.numWriteCalls);
669671
Assert.AreEqual(numSteps, agent1.sensor2.numCompressedCalls);
672+
673+
// Make sure the Heuristic method read the observation and set the action
674+
Assert.AreEqual(agent1.collectObservationsCallsForEpisode, agent1.GetAction()[0]);
670675
}
671676
}
672677

0 commit comments

Comments
 (0)