Skip to content

Commit 5102adb

Browse files
author
Chris Elion
authored
[MLA-345] float visual observations (#3148)
* pass shape to WriteAdapter * handle floats on python side * cleanup * whitespace * rename GetFloatObservationShape, support uncompressed in RenderTexture sensor * numpy float32 * remove unused using * Float sensor and unit test * replace asserts with exceptions, docstrings
1 parent ae1e6c9 commit 5102adb

26 files changed

+858
-572
lines changed

UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ public TestSensor(string n)
8080
sensorName = n;
8181
}
8282

83-
public int[] GetFloatObservationShape()
83+
public int[] GetObservationShape()
8484
{
8585
return new[] { 0 };
8686
}
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
using NUnit.Framework;
2+
using UnityEngine;
3+
using MLAgents.Sensor;
4+
5+
namespace MLAgents.Tests
6+
{
7+
public class Float2DSensor : ISensor
8+
{
9+
public int Width { get; }
10+
public int Height { get; }
11+
string m_Name;
12+
int[] m_Shape;
13+
public float[,] floatData;
14+
15+
public Float2DSensor(int width, int height, string name)
16+
{
17+
Width = width;
18+
Height = height;
19+
m_Name = name;
20+
m_Shape = new[] { height, width, 1 };
21+
floatData = new float[Height, Width];
22+
}
23+
24+
public Float2DSensor(float[,] floatData, string name)
25+
{
26+
this.floatData = floatData;
27+
Height = floatData.GetLength(0);
28+
Width = floatData.GetLength(1);
29+
m_Name = name;
30+
m_Shape = new[] { Height, Width, 1 };
31+
}
32+
33+
public string GetName()
34+
{
35+
return m_Name;
36+
}
37+
38+
public int[] GetObservationShape()
39+
{
40+
return m_Shape;
41+
}
42+
43+
public byte[] GetCompressedObservation()
44+
{
45+
return null;
46+
}
47+
48+
public int Write(WriteAdapter adapter)
49+
{
50+
using (TimerStack.Instance.Scoped("Float2DSensor.Write"))
51+
{
52+
for (var h = 0; h < Height; h++)
53+
{
54+
for (var w = 0; w < Width; w++)
55+
{
56+
adapter[h, w, 0] = floatData[h, w];
57+
}
58+
}
59+
var numWritten = Height * Width;
60+
return numWritten;
61+
}
62+
}
63+
64+
public void Update() { }
65+
66+
public SensorCompressionType GetCompressionType()
67+
{
68+
return SensorCompressionType.None;
69+
}
70+
}
71+
72+
public class FloatVisualSensorTests
73+
{
74+
[Test]
75+
public void TestFloat2DSensorWrite()
76+
{
77+
var sensor = new Float2DSensor(3, 4, "floatsensor");
78+
for (var h = 0; h < 4; h++)
79+
{
80+
for (var w = 0; w < 3; w++)
81+
{
82+
sensor.floatData[h, w] = 3 * h + w;
83+
}
84+
}
85+
86+
var output = new float[12];
87+
var writer = new WriteAdapter();
88+
writer.SetTarget(output, sensor.GetObservationShape(), 0);
89+
sensor.Write(writer);
90+
for (var i = 0; i < 9; i++)
91+
{
92+
Assert.AreEqual(i, output[i]);
93+
}
94+
}
95+
96+
[Test]
97+
public void TestFloat2DSensorExternalData()
98+
{
99+
var data = new float[4, 3];
100+
var sensor = new Float2DSensor(data, "floatsensor");
101+
Assert.AreEqual(sensor.Height, 4);
102+
Assert.AreEqual(sensor.Width, 3);
103+
}
104+
}
105+
}

UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/FloatVisualSensorTests.cs.meta

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/StackingSensorTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public void TestCtor()
1212
ISensor wrapped = new VectorSensor(4);
1313
ISensor sensor = new StackingSensor(wrapped, 4);
1414
Assert.AreEqual("StackingSensor_size4_VectorSensor_size4", sensor.GetName());
15-
Assert.AreEqual(sensor.GetFloatObservationShape(), new [] {16});
15+
Assert.AreEqual(sensor.GetObservationShape(), new [] {16});
1616
}
1717

1818
[Test]

UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/VectorSensorTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public static void CompareObservation(ISensor sensor, float[] expected)
1818
Assert.AreEqual(fill, output[0]);
1919

2020
WriteAdapter writer = new WriteAdapter();
21-
writer.SetTarget(output, 0);
21+
writer.SetTarget(output, sensor.GetObservationShape(), 0);
2222

2323
// Make sure WriteAdapter didn't touch anything
2424
Assert.AreEqual(fill, output[0]);

UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/WriterAdapterTests.cs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,26 @@ public void TestWritesToIList()
1515
{
1616
WriteAdapter writer = new WriteAdapter();
1717
var buffer = new[] { 0f, 0f, 0f };
18+
var shape = new[] { 3 };
1819

19-
writer.SetTarget(buffer, 0);
20+
writer.SetTarget(buffer, shape, 0);
2021
// Elementwise writes
2122
writer[0] = 1f;
2223
writer[2] = 2f;
2324
Assert.AreEqual(new[] { 1f, 0f, 2f }, buffer);
2425

2526
// Elementwise writes with offset
26-
writer.SetTarget(buffer, 1);
27+
writer.SetTarget(buffer, shape, 1);
2728
writer[0] = 3f;
2829
Assert.AreEqual(new[] { 1f, 3f, 2f }, buffer);
2930

3031
// AddRange
31-
writer.SetTarget(buffer, 0);
32+
writer.SetTarget(buffer, shape, 0);
3233
writer.AddRange(new [] {4f, 5f});
3334
Assert.AreEqual(new[] { 4f, 5f, 2f }, buffer);
3435

3536
// AddRange with offset
36-
writer.SetTarget(buffer, 1);
37+
writer.SetTarget(buffer, shape, 1);
3738
writer.AddRange(new [] {6f, 7f});
3839
Assert.AreEqual(new[] { 4f, 6f, 7f }, buffer);
3940
}
@@ -47,12 +48,13 @@ public void TestWritesToTensor()
4748
valueType = TensorProxy.TensorType.FloatingPoint,
4849
data = new Tensor(2, 3)
4950
};
50-
writer.SetTarget(t, 0, 0);
51+
var shape = new[] { 3 };
52+
writer.SetTarget(t, shape, 0, 0);
5153
Assert.AreEqual(0f, t.data[0, 0]);
5254
writer[0] = 1f;
5355
Assert.AreEqual(1f, t.data[0, 0]);
5456

55-
writer.SetTarget(t, 1, 1);
57+
writer.SetTarget(t, shape, 1, 1);
5658
writer[0] = 2f;
5759
writer[1] = 3f;
5860
// [0, 0] shouldn't change
@@ -67,7 +69,7 @@ public void TestWritesToTensor()
6769
data = new Tensor(2, 3)
6870
};
6971

70-
writer.SetTarget(t, 1, 1);
72+
writer.SetTarget(t, shape, 1, 1);
7173
writer.AddRange(new [] {-1f, -2f});
7274
Assert.AreEqual(0f, t.data[0, 0]);
7375
Assert.AreEqual(0f, t.data[0, 1]);
@@ -87,11 +89,13 @@ public void TestWritesToTensor3D()
8789
data = new Tensor(2, 2, 2, 3)
8890
};
8991

90-
writer.SetTarget(t, 0, 0);
92+
var shape = new[] { 2, 2, 3 };
93+
94+
writer.SetTarget(t, shape, 0, 0);
9195
writer[1, 0, 1] = 1f;
9296
Assert.AreEqual(1f, t.data[0, 1, 0, 1]);
9397

94-
writer.SetTarget(t, 0, 1);
98+
writer.SetTarget(t, shape, 0, 1);
9599
writer[1, 0, 0] = 2f;
96100
Assert.AreEqual(2f, t.data[0, 1, 0, 1]);
97101
}

0 commit comments

Comments
 (0)