Skip to content

make StackingSensor public #3701

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- The way that UnityEnvironment decides the port was changed. If no port is specified, the behavior will depend on the `file_name` parameter. If it is `None`, 5004 (the editor port) will be used; otherwise 5005 (the base environment port) will be used.
- Fixed an issue where exceptions from environments provided a returncode of 0. (#3680)
- Running `mlagents-learn` with the same `--run-id` twice will no longer overwrite the existing files. (#3705)
- `StackingSensor` was changed from `internal` visibility to `public`

## [0.15.1-preview] - 2020-03-30
### Bug Fixes
Expand Down
20 changes: 18 additions & 2 deletions com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ namespace MLAgents.Sensors
/// For example, 4 stacked sets of observations would be output like
/// | t = now - 3 | t = now -3 | t = now - 2 | t = now |
/// Internally, a circular buffer of arrays is used. The m_CurrentIndex represents the most recent observation.
///
/// Currently, compressed and multidimensional observations are not supported.
/// </summary>
internal class StackingSensor : ISensor
public class StackingSensor : ISensor
{
/// <summary>
/// The wrapped sensor.
Expand All @@ -32,7 +34,7 @@ internal class StackingSensor : ISensor
WriteAdapter m_LocalAdapter = new WriteAdapter();

/// <summary>
///
/// Initializes the sensor.
/// </summary>
/// <param name="wrapped">The wrapped sensor.</param>
/// <param name="numStackedObservations">Number of stacked observations to keep.</param>
Expand All @@ -44,7 +46,16 @@ public StackingSensor(ISensor wrapped, int numStackedObservations)

m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";

if (wrapped.GetCompressionType() != SensorCompressionType.None)
{
throw new UnityAgentsException("StackingSensor doesn't support compressed observations.'");
}

var shape = wrapped.GetObservationShape();
if (shape.Length != 1)
{
throw new UnityAgentsException("Only 1-D observations are supported by StackingSensor");
}
m_Shape = new int[shape.Length];

m_UnstackedObservationSize = wrapped.ObservationSize();
Expand All @@ -62,6 +73,7 @@ public StackingSensor(ISensor wrapped, int numStackedObservations)
}
}

/// <inheritdoc/>
public int Write(WriteAdapter adapter)
{
// First, call the wrapped sensor's write method. Make sure to use our own adapter, not the passed one.
Expand Down Expand Up @@ -90,21 +102,25 @@ public void Update()
m_CurrentIndex = (m_CurrentIndex + 1) % m_NumStackedObservations;
}

/// <inheritdoc/>
public int[] GetObservationShape()
{
return m_Shape;
}

/// <inheritdoc/>
public string GetName()
{
return m_Name;
}

/// <inheritdoc/>
public virtual byte[] GetCompressedObservation()
{
return null;
}

/// <inheritdoc/>
public virtual SensorCompressionType GetCompressionType()
{
return SensorCompressionType.None;
Expand Down
30 changes: 30 additions & 0 deletions com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,30 @@ public override float[] Heuristic()
}
}

// Simple SensorComponent that sets up a StackingSensor
class StackingComponent : SensorComponent
{
public SensorComponent wrappedComponent;
public int numStacks;

public override ISensor CreateSensor()
{
var wrappedSensor = wrappedComponent.CreateSensor();
return new StackingSensor(wrappedSensor, numStacks);
}

public override int[] GetObservationShape()
{
int[] shape = (int[]) wrappedComponent.GetObservationShape().Clone();
for (var i = 0; i < shape.Length; i++)
{
shape[i] *= numStacks;
}

return shape;
}
}


[Test]
public void CheckSetupAgent()
Expand Down Expand Up @@ -114,6 +138,12 @@ public void CheckSetupAgent()
sensorComponent.detectableTags = new List<string> { "Player", "Respawn" };
sensorComponent.raysPerDirection = 3;

// Make a StackingSensor that wraps the RayPerceptionSensorComponent3D
// This isn't necessarily practical, just to ensure that it can be done
var wrappingSensorComponent = gameObject.AddComponent<StackingComponent>();
wrappingSensorComponent.wrappedComponent = sensorComponent;
wrappingSensorComponent.numStacks = 3;

// ISensor isn't set up yet.
Assert.IsNull(sensorComponent.raySensor);

Expand Down