Skip to content

Commit e4e9c51

Browse files
author
Chris Elion
authored
[MLA-1824] make SensorComponent return ISensor[] (#5181)
* Make SensorComponent return an array * split match3 sensors, partial retrain * docstrings, migration, changelog, cleanup
1 parent 61beb8c commit e4e9c51

33 files changed

+344
-294
lines changed

Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicSensorComponent.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ public class BasicSensorComponent : SensorComponent
1515
/// Creates a BasicSensor.
1616
/// </summary>
1717
/// <returns></returns>
18-
public override ISensor CreateSensor()
18+
public override ISensor[] CreateSensors()
1919
{
20-
return new BasicSensor(basicController);
20+
return new ISensor[] { new BasicSensor(basicController) };
2121
}
2222
}
2323

Binary file not shown.
Binary file not shown.

Project/Assets/ML-Agents/TestScenes/TestCompressedTexture/TestTextureSensorComponent.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ public string SensorName
2323

2424

2525
/// <inheritdoc/>
26-
public override ISensor CreateSensor()
26+
public override ISensor[] CreateSensors()
2727
{
2828
m_Sensor = new TestTextureSensor(TestTexture, SensorName, CompressionType);
2929
if (ObservationStacks != 1)
3030
{
31-
return new StackingSensor(m_Sensor, ObservationStacks);
31+
return new ISensor[] { new StackingSensor(m_Sensor, ObservationStacks) };
3232
}
33-
return m_Sensor;
33+
return new ISensor[] { m_Sensor };
3434
}
3535
}
3636

com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs

Lines changed: 65 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44

55
namespace Unity.MLAgents.Extensions.Match3
66
{
7+
8+
/// <summary>
9+
/// Delegate that provides integer values at a given (x,y) coordinate.
10+
/// </summary>
11+
/// <param name="x"></param>
12+
/// <param name="y"></param>
13+
public delegate int GridValueProvider(int x, int y);
14+
715
/// <summary>
816
/// Type of observations to generate.
917
///
@@ -32,66 +40,68 @@ public enum Match3ObservationType
3240

3341
/// <summary>
3442
/// Sensor for Match3 games. Can generate either vector, compressed visual,
35-
/// or uncompressed visual observations. Uses AbstractBoard.GetCellType()
36-
/// and AbstractBoard.GetSpecialType() to determine the observation values.
43+
/// or uncompressed visual observations. Uses a GridValueProvider to determine the observation values.
3744
/// </summary>
3845
public class Match3Sensor : ISensor, IBuiltInSensor
3946
{
4047
private Match3ObservationType m_ObservationType;
41-
private AbstractBoard m_Board;
4248
private ObservationSpec m_ObservationSpec;
43-
private int[] m_SparseChannelMapping;
4449
private string m_Name;
4550

4651
private int m_Rows;
4752
private int m_Columns;
48-
private int m_NumCellTypes;
49-
private int m_NumSpecialTypes;
50-
51-
private int SpecialTypeSize
52-
{
53-
get { return m_NumSpecialTypes == 0 ? 0 : m_NumSpecialTypes + 1; }
54-
}
53+
private GridValueProvider m_GridValues;
54+
private int m_OneHotSize;
5555

5656
/// <summary>
57-
/// Create a sensor for the board with the specified observation type.
57+
/// Create a sensor for the GridValueProvider with the specified observation type.
5858
/// </summary>
59-
/// <param name="board"></param>
60-
/// <param name="obsType"></param>
61-
/// <param name="name"></param>
62-
public Match3Sensor(AbstractBoard board, Match3ObservationType obsType, string name)
59+
/// <remarks>
60+
/// Use Match3Sensor.CellTypeSensor() or Match3Sensor.SpecialTypeSensor() instead of calling
61+
/// the constructor directly.
62+
/// </remarks>
63+
/// <param name="board">The abstract board. This is only used to get the size.</param>
64+
/// <param name="gvp">The GridValueProvider, should be either board.GetCellType or board.GetSpecialType.</param>
65+
/// <param name="oneHotSize">The number of possible values that the GridValueProvider can return.</param>
66+
/// <param name="obsType">Whether to produce vector or visual observations</param>
67+
/// <param name="name">Name of the sensor.</param>
68+
public Match3Sensor(AbstractBoard board, GridValueProvider gvp, int oneHotSize, Match3ObservationType obsType, string name)
6369
{
64-
m_Board = board;
6570
m_Name = name;
6671
m_Rows = board.Rows;
6772
m_Columns = board.Columns;
68-
m_NumCellTypes = board.NumCellTypes;
69-
m_NumSpecialTypes = board.NumSpecialTypes;
73+
m_GridValues = gvp;
74+
m_OneHotSize = oneHotSize;
7075

7176
m_ObservationType = obsType;
7277
m_ObservationSpec = obsType == Match3ObservationType.Vector
73-
? ObservationSpec.Vector(m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize))
74-
: ObservationSpec.Visual(m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize);
75-
76-
// See comment in GetCompressedObservation()
77-
var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3);
78-
m_SparseChannelMapping = new int[cellTypePaddedSize + SpecialTypeSize];
79-
// If we have 4 cell types and 2 special types (3 special size), we'd have
80-
// [0, 1, 2, 3, -1, -1, 4, 5, 6]
81-
for (var i = 0; i < m_NumCellTypes; i++)
82-
{
83-
m_SparseChannelMapping[i] = i;
84-
}
78+
? ObservationSpec.Vector(m_Rows * m_Columns * oneHotSize)
79+
: ObservationSpec.Visual(m_Rows, m_Columns, oneHotSize);
80+
}
8581

86-
for (var i = m_NumCellTypes; i < cellTypePaddedSize; i++)
87-
{
88-
m_SparseChannelMapping[i] = -1;
89-
}
82+
/// <summary>
83+
/// Create a sensor that encodes the board cells as observations.
84+
/// </summary>
85+
/// <param name="board">The abstract board.</param>
86+
/// <param name="obsType">Whether to produce vector or visual observations</param>
87+
/// <param name="name">Name of the sensor.</param>
88+
/// <returns></returns>
89+
public static Match3Sensor CellTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name)
90+
{
91+
return new Match3Sensor(board, board.GetCellType, board.NumCellTypes, obsType, name);
92+
}
9093

91-
for (var i = 0; i < SpecialTypeSize; i++)
92-
{
93-
m_SparseChannelMapping[cellTypePaddedSize + i] = i + m_NumCellTypes;
94-
}
94+
/// <summary>
95+
/// Create a sensor that encodes the cell special types as observations.
96+
/// </summary>
97+
/// <param name="board">The abstract board.</param>
98+
/// <param name="obsType">Whether to produce vector or visual observations</param>
99+
/// <param name="name">Name of the sensor.</param>
100+
/// <returns></returns>
101+
public static Match3Sensor SpecialTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name)
102+
{
103+
var specialSize = board.NumSpecialTypes == 0 ? 0 : board.NumSpecialTypes + 1;
104+
return new Match3Sensor(board, board.GetSpecialType, specialSize, obsType, name);
95105
}
96106

97107
/// <inheritdoc/>
@@ -103,14 +113,14 @@ public ObservationSpec GetObservationSpec()
103113
/// <inheritdoc/>
104114
public int Write(ObservationWriter writer)
105115
{
106-
if (m_Board.Rows != m_Rows || m_Board.Columns != m_Columns || m_Board.NumCellTypes != m_NumCellTypes)
107-
{
108-
Debug.LogWarning(
109-
$"Board shape changes since sensor initialization. This may cause unexpected results. " +
110-
$"Old shape: Rows={m_Rows} Columns={m_Columns}, NumCellTypes={m_NumCellTypes} " +
111-
$"Current shape: Rows={m_Board.Rows} Columns={m_Board.Columns}, NumCellTypes={m_Board.NumCellTypes}"
112-
);
113-
}
116+
// if (m_Board.Rows != m_Rows || m_Board.Columns != m_Columns || m_Board.NumCellTypes != m_NumCellTypes)
117+
// {
118+
// Debug.LogWarning(
119+
// $"Board shape changes since sensor initialization. This may cause unexpected results. " +
120+
// $"Old shape: Rows={m_Rows} Columns={m_Columns}, NumCellTypes={m_NumCellTypes} " +
121+
// $"Current shape: Rows={m_Board.Rows} Columns={m_Board.Columns}, NumCellTypes={m_Board.NumCellTypes}"
122+
// );
123+
// }
114124

115125
if (m_ObservationType == Match3ObservationType.Vector)
116126
{
@@ -119,22 +129,13 @@ public int Write(ObservationWriter writer)
119129
{
120130
for (var c = 0; c < m_Columns; c++)
121131
{
122-
var val = m_Board.GetCellType(r, c);
123-
for (var i = 0; i < m_NumCellTypes; i++)
132+
var val = m_GridValues(r, c);
133+
134+
for (var i = 0; i < m_OneHotSize; i++)
124135
{
125136
writer[offset] = (i == val) ? 1.0f : 0.0f;
126137
offset++;
127138
}
128-
129-
if (m_NumSpecialTypes > 0)
130-
{
131-
var special = m_Board.GetSpecialType(r, c);
132-
for (var i = 0; i < SpecialTypeSize; i++)
133-
{
134-
writer[offset] = (i == special) ? 1.0f : 0.0f;
135-
offset++;
136-
}
137-
}
138139
}
139140
}
140141

@@ -148,22 +149,12 @@ public int Write(ObservationWriter writer)
148149
{
149150
for (var c = 0; c < m_Columns; c++)
150151
{
151-
var val = m_Board.GetCellType(r, c);
152-
for (var i = 0; i < m_NumCellTypes; i++)
152+
var val = m_GridValues(r, c);
153+
for (var i = 0; i < m_OneHotSize; i++)
153154
{
154155
writer[r, c, i] = (i == val) ? 1.0f : 0.0f;
155156
offset++;
156157
}
157-
158-
if (m_NumSpecialTypes > 0)
159-
{
160-
var special = m_Board.GetSpecialType(r, c);
161-
for (var i = 0; i < SpecialTypeSize; i++)
162-
{
163-
writer[offset] = (i == special) ? 1.0f : 0.0f;
164-
offset++;
165-
}
166-
}
167158
}
168159
}
169160

@@ -185,17 +176,10 @@ public byte[] GetCompressedObservation()
185176
// fit in in 2 images, but we'll use 3 here (2 PNGs for the 4 cell type channels, and 1 for
186177
// the special types). Note that we have to also implement the sparse channel mapping.
187178
// Optimize this it later.
188-
var numCellImages = (m_NumCellTypes + 2) / 3;
179+
var numCellImages = (m_OneHotSize + 2) / 3;
189180
for (var i = 0; i < numCellImages; i++)
190181
{
191-
converter.EncodeToTexture(m_Board.GetCellType, tempTexture, 3 * i);
192-
bytesOut.AddRange(tempTexture.EncodeToPNG());
193-
}
194-
195-
var numSpecialImages = (SpecialTypeSize + 2) / 3;
196-
for (var i = 0; i < numSpecialImages; i++)
197-
{
198-
converter.EncodeToTexture(m_Board.GetSpecialType, tempTexture, 3 * i);
182+
converter.EncodeToTexture(m_GridValues, tempTexture, 3 * i);
199183
bytesOut.AddRange(tempTexture.EncodeToPNG());
200184
}
201185

@@ -223,7 +207,7 @@ internal SensorCompressionType GetCompressionType()
223207
/// <inheritdoc/>
224208
public CompressionSpec GetCompressionSpec()
225209
{
226-
return new CompressionSpec(GetCompressionType(), m_SparseChannelMapping);
210+
return new CompressionSpec(GetCompressionType());
227211
}
228212

229213
/// <inheritdoc/>
@@ -265,9 +249,6 @@ internal class OneHotToTextureUtil
265249
int m_Width;
266250
private static Color[] s_OneHotColors = { Color.red, Color.green, Color.blue };
267251

268-
public delegate int GridValueProvider(int x, int y);
269-
270-
271252
public OneHotToTextureUtil(int height, int width)
272253
{
273254
m_Colors = new Color[height * width];

com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,20 @@ public class Match3SensorComponent : SensorComponent
2121
public Match3ObservationType ObservationType = Match3ObservationType.Vector;
2222

2323
/// <inheritdoc/>
24-
public override ISensor CreateSensor()
24+
public override ISensor[] CreateSensors()
2525
{
2626
var board = GetComponent<AbstractBoard>();
27-
return new Match3Sensor(board, ObservationType, SensorName);
27+
var cellSensor = Match3Sensor.CellTypeSensor(board, ObservationType, SensorName + " (cells)");
28+
if (board.NumSpecialTypes > 0)
29+
{
30+
var specialSensor =
31+
Match3Sensor.SpecialTypeSensor(board, ObservationType, SensorName + " (special)");
32+
return new ISensor[] { cellSensor, specialSensor };
33+
}
34+
else
35+
{
36+
return new ISensor[] { cellSensor };
37+
}
2838
}
2939

3040
}

com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ public class ArticulationBodySensorComponent : SensorComponent
1616
/// Creates a PhysicsBodySensor.
1717
/// </summary>
1818
/// <returns></returns>
19-
public override ISensor CreateSensor()
19+
public override ISensor[] CreateSensors()
2020
{
21-
return new PhysicsBodySensor(RootBody, Settings, sensorName);
21+
return new ISensor[] {new PhysicsBodySensor(RootBody, Settings, sensorName)};
2222
}
2323

2424
}

com.unity.ml-agents.extensions/Runtime/Sensors/GridSensor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,9 @@ public enum GridDepthType { Channel, ChannelHot };
267267
private Color DebugDefaultColor = new Color(1f, 1f, 1f, 0.25f);
268268

269269
/// <inheritdoc/>
270-
public override ISensor CreateSensor()
270+
public override ISensor[] CreateSensors()
271271
{
272-
return this;
272+
return new ISensor[] { this };
273273
}
274274

275275
/// <summary>

com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ public class RigidBodySensorComponent : SensorComponent
3939
/// Creates a PhysicsBodySensor.
4040
/// </summary>
4141
/// <returns></returns>
42-
public override ISensor CreateSensor()
42+
public override ISensor[] CreateSensors()
4343
{
4444
var _sensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{RootBody?.name}" : sensorName;
45-
return new PhysicsBodySensor(GetPoseExtractor(), Settings, _sensorName);
45+
return new ISensor[] { new PhysicsBodySensor(GetPoseExtractor(), Settings, _sensorName) };
4646
}
4747

4848
/// <summary>

0 commit comments

Comments
 (0)