Skip to content

[MLA-1634] Remove SensorComponent.GetObservationShape() #5172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,6 @@ public override ISensor CreateSensor()
{
return new BasicSensor(basicController);
}

/// <inheritdoc/>
public override int[] GetObservationShape()
{
return new[] { BasicController.k_Extents };
}
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public abstract class SensorBase : ISensor
{
/// <summary>
/// Write the observations to the output buffer. This size of the buffer will be product
/// of the sizes returned by <see cref="GetObservationShape"/>.
/// of the Shape array values returned by <see cref="ObservationSpec"/>.
/// </summary>
/// <param name="output"></param>
public abstract void WriteObservation(float[] output);
Expand All @@ -28,7 +28,7 @@ public abstract class SensorBase : ISensor
/// <returns>The number of elements written.</returns>
public virtual int Write(ObservationWriter writer)
{
// TODO reuse buffer for similar agents, don't call GetObservationShape()
// TODO reuse buffer for similar agents
var numFloats = this.ObservationSize();
float[] buffer = new float[numFloats];
WriteObservation(buffer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,5 @@ public override ISensor CreateSensor()
}
return m_Sensor;
}

/// <inheritdoc/>
public override int[] GetObservationShape()
{
var width = TestTexture.width;
var height = TestTexture.height;
var observationShape = new[] { height, width, 3 };

var stacks = ObservationStacks > 1 ? ObservationStacks : 1;
if (stacks > 1)
{
observationShape[2] *= stacks;
}

return observationShape;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,5 @@ public override ISensor CreateSensor()
return new Match3Sensor(board, ObservationType, SensorName);
}

/// <inheritdoc/>
public override int[] GetObservationShape()
{
var board = GetComponent<AbstractBoard>();
if (board == null)
{
return System.Array.Empty<int>();
}

var specialSize = board.NumSpecialTypes == 0 ? 0 : board.NumSpecialTypes + 1;
return ObservationType == Match3ObservationType.Vector ?
new[] { board.Rows * board.Columns * (board.NumCellTypes + specialSize) } :
new[] { board.Rows, board.Columns, board.NumCellTypes + specialSize };
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,6 @@ public override ISensor CreateSensor()
return new PhysicsBodySensor(RootBody, Settings, sensorName);
}

/// <inheritdoc/>
public override int[] GetObservationShape()
{
if (RootBody == null)
{
return new[] { 0 };
}

// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
var poseExtractor = new ArticulationBodyPoseExtractor(RootBody);
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);
var numJointObservations = 0;

foreach(var articBody in poseExtractor.GetEnabledArticulationBodies())
{
numJointObservations += ArticulationBodyJointExtractor.NumObservations(articBody, Settings);
}
return new[] { numPoseObservations + numJointObservations };
}
}

}
Expand Down
7 changes: 0 additions & 7 deletions com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -917,13 +917,6 @@ public ObservationSpec GetObservationSpec()
return m_ObservationSpec;
}

/// <inheritdoc/>
public override int[] GetObservationShape()
{
var shape = m_ObservationSpec.Shape;
return new int[] { shape[0], shape[1], shape[2] };
}

/// <inheritdoc/>
public int Write(ObservationWriter writer)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ string sensorName
}

#if UNITY_2020_1_OR_NEWER
public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null)
public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName = null)
{
var poseExtractor = new ArticulationBodyPoseExtractor(rootBody);
m_PoseExtractor = poseExtractor;
Expand All @@ -57,7 +57,7 @@ public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settin

var numJointExtractorObservations = 0;
m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
foreach(var articBody in poseExtractor.GetEnabledArticulationBodies())
foreach (var articBody in poseExtractor.GetEnabledArticulationBodies())
{
var jointExtractor = new ArticulationBodyJointExtractor(articBody);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
Expand All @@ -67,6 +67,7 @@ public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settin
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);
m_ObservationSpec = ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
}

#endif

/// <inheritdoc/>
Expand Down Expand Up @@ -126,6 +127,5 @@ public BuiltInSensorType GetBuiltInSensorType()
{
return BuiltInSensorType.PhysicsBodySensor;
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,26 +45,6 @@ public override ISensor CreateSensor()
return new PhysicsBodySensor(GetPoseExtractor(), Settings, _sensorName);
}

/// <inheritdoc/>
public override int[] GetObservationShape()
{
if (RootBody == null)
{
return new[] { 0 };
}

var poseExtractor = GetPoseExtractor();
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);

var numJointObservations = 0;
foreach (var rb in poseExtractor.GetEnabledRigidbodies())
{
var joint = rb.GetComponent<Joint>();
numJointObservations += RigidBodyJointExtractor.NumObservations(rb, joint, Settings);
}
return new[] { numPoseObservations + numJointObservations };
}

/// <summary>
/// Get the DisplayNodes of the hierarchy.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ public void TestVectorObservations()
sensorComponent.ObservationType = Match3ObservationType.Vector;
var sensor = sensorComponent.CreateSensor();

var expectedShape = new[] { 3 * 3 * 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3 * 3 * 2);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);

var expectedObs = new float[]
{
Expand Down Expand Up @@ -63,9 +62,8 @@ public void TestVectorObservationsSpecial()
sensorComponent.ObservationType = Match3ObservationType.Vector;
var sensor = sensorComponent.CreateSensor();

var expectedShape = new[] { 3 * 3 * (2 + 3) };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3 * 3 * (2 + 3));
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);

var expectedObs = new float[]
{
Expand All @@ -76,7 +74,6 @@ public void TestVectorObservationsSpecial()
SensorTestHelper.CompareObservation(sensor, expectedObs);
}


[Test]
public void TestVisualObservations()
{
Expand All @@ -92,9 +89,8 @@ public void TestVisualObservations()
sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual;
var sensor = sensorComponent.CreateSensor();

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3, 3, 2);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionSpec().SensorCompressionType);

Expand Down Expand Up @@ -136,9 +132,8 @@ public void TestVisualObservationsSpecial()
sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual;
var sensor = sensorComponent.CreateSensor();

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3, 3, 2 + 3);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionSpec().SensorCompressionType);

Expand Down Expand Up @@ -174,9 +169,8 @@ public void TestCompressedVisualObservations()
sensorComponent.ObservationType = Match3ObservationType.CompressedVisual;
var sensor = sensorComponent.CreateSensor();

var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3, 3, 2);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType);

Expand All @@ -191,8 +185,6 @@ public void TestCompressedVisualObservations()
Assert.AreEqual(expectedPng, pngData);
}



[Test]
public void TestCompressedVisualObservationsSpecial()
{
Expand All @@ -214,9 +206,8 @@ public void TestCompressedVisualObservationsSpecial()
sensorComponent.ObservationType = Match3ObservationType.CompressedVisual;
var sensor = sensorComponent.CreateSensor();

var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(InplaceArray<int>.FromList(expectedShape), sensor.GetObservationSpec().Shape);
var expectedShape = new InplaceArray<int>(3, 3, 2 + 3);
Assert.AreEqual(expectedShape, sensor.GetObservationSpec().Shape);

Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionSpec().SensorCompressionType);

Expand All @@ -229,7 +220,6 @@ public void TestCompressedVisualObservationsSpecial()
}
var expectedPng = LoadPNGs(pathPrefix, 2);
Assert.AreEqual(expectedPng, concatenatedPngData);

}

/// <summary>
Expand Down Expand Up @@ -306,7 +296,6 @@ byte[] LoadPNGs(string pathPrefix, int numExpected)
}

return bytesOut.ToArray();

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ public void OneChannelDepthOne()
1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();

int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());

var expectedShape = new InplaceArray<int>(10, 10, 1);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);
}


Expand All @@ -51,9 +50,8 @@ public void OneChannelDepthTwo()
1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();

int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());

var expectedShape = new InplaceArray<int>(10, 10, 2);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);
}

[Test]
Expand All @@ -66,8 +64,8 @@ public void TwoChannelsDepthTwoOne()
1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();

int[] expectedShape = { 10, 10, 3 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
var expectedShape = new InplaceArray<int>(10, 10, 3);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);

}

Expand All @@ -81,9 +79,8 @@ public void TwoChannelsDepthThreeThree()
1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();

int[] expectedShape = { 10, 10, 6 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());

var expectedShape = new InplaceArray<int>(10, 10, 6);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ public void OneChannel()
1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();

int[] expectedShape = { 10, 10, 1 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
var expectedShape = new InplaceArray<int>(10, 10, 1);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);

}

[Test]
Expand All @@ -48,8 +49,9 @@ public void TwoChannel()
1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();

int[] expectedShape = { 10, 10, 2 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
var expectedShape = new InplaceArray<int>(10, 10, 2);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);

}

[Test]
Expand All @@ -62,8 +64,9 @@ public void SevenChannel()
1f, 1f, 10, 10, LayerMask.GetMask("Default"), false, colors);
gridSensor.Start();

int[] expectedShape = { 10, 10, 7 };
GridObsTestUtils.AssertArraysAreEqual(expectedShape, gridSensor.GetObservationShape());
var expectedShape = new InplaceArray<int>(10, 10, 7);
Assert.AreEqual(expectedShape, gridSensor.GetObservationSpec().Shape);

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,33 +66,6 @@ public static float[][] DuplicateArray(float[] array, int numCopies)
return duplicated;
}

/// <summary>
/// Asserts that 2 int arrays are the same
/// </summary>
/// <param name="expected">The expected array</param>
/// <param name="actual">The actual array</param>
public static void AssertArraysAreEqual(int[] expected, int[] actual)
{
Assert.AreEqual(expected.Length, actual.Length, "Lengths are not the same");
for (int i = 0; i < actual.Length; i++)
{
Assert.AreEqual(expected[i], actual[i], "Got " + Array2Str(actual) + ", expected " + Array2Str(expected));
}
}

/// <summary>
/// Asserts that 2 float arrays are the same
/// </summary>
/// <param name="expected">The expected array</param>
/// <param name="actual">The actual array</param>
public static void AssertArraysAreEqual(float[] expected, float[] actual)
{
Assert.AreEqual(expected.Length, actual.Length, "Lengths are not the same");
for (int i = 0; i < actual.Length; i++)
{
Assert.AreEqual(expected[i], actual[i], "Got " + Array2Str(actual) + ", expected " + Array2Str(expected));
}
}

/// <summary>
/// Asserts that the sub-arrays of the total array are equal to specific subarrays at specific subarray indicies and equal to a default everywhere else.
Expand Down
Loading