4
4
5
5
namespace Unity . MLAgents . Extensions . Match3
6
6
{
7
+
8
+ /// <summary>
9
+ /// Delegate that provides integer values at a given (x,y) coordinate.
10
+ /// </summary>
11
+ /// <param name="x"></param>
12
+ /// <param name="y"></param>
13
+ public delegate int GridValueProvider ( int x , int y ) ;
14
+
7
15
/// <summary>
8
16
/// Type of observations to generate.
9
17
///
@@ -32,66 +40,68 @@ public enum Match3ObservationType
32
40
33
41
/// <summary>
34
42
/// Sensor for Match3 games. Can generate either vector, compressed visual,
35
- /// or uncompressed visual observations. Uses AbstractBoard.GetCellType()
36
- /// and AbstractBoard.GetSpecialType() to determine the observation values.
43
+ /// or uncompressed visual observations. Uses a GridValueProvider to determine the observation values.
37
44
/// </summary>
38
45
public class Match3Sensor : ISensor , IBuiltInSensor
39
46
{
40
47
private Match3ObservationType m_ObservationType ;
41
- private AbstractBoard m_Board ;
42
48
private ObservationSpec m_ObservationSpec ;
43
- private int [ ] m_SparseChannelMapping ;
44
49
private string m_Name ;
45
50
46
51
private int m_Rows ;
47
52
private int m_Columns ;
48
- private int m_NumCellTypes ;
49
- private int m_NumSpecialTypes ;
50
-
51
- private int SpecialTypeSize
52
- {
53
- get { return m_NumSpecialTypes == 0 ? 0 : m_NumSpecialTypes + 1 ; }
54
- }
53
+ private GridValueProvider m_GridValues ;
54
+ private int m_OneHotSize ;
55
55
56
56
/// <summary>
57
- /// Create a sensor for the board with the specified observation type.
57
+ /// Create a sensor for the GridValueProvider with the specified observation type.
58
58
/// </summary>
59
- /// <param name="board"></param>
60
- /// <param name="obsType"></param>
61
- /// <param name="name"></param>
62
- public Match3Sensor ( AbstractBoard board , Match3ObservationType obsType , string name )
59
+ /// <remarks>
60
+ /// Use Match3Sensor.CellTypeSensor() or Match3Sensor.SpecialTypeSensor() instead of calling
61
+ /// the constructor directly.
62
+ /// </remarks>
63
+ /// <param name="board">The abstract board. This is only used to get the size.</param>
64
+ /// <param name="gvp">The GridValueProvider, should be either board.GetCellType or board.GetSpecialType.</param>
65
+ /// <param name="oneHotSize">The number of possible values that the GridValueProvider can return.</param>
66
+ /// <param name="obsType">Whether to produce vector or visual observations</param>
67
+ /// <param name="name">Name of the sensor.</param>
68
+ public Match3Sensor ( AbstractBoard board , GridValueProvider gvp , int oneHotSize , Match3ObservationType obsType , string name )
63
69
{
64
- m_Board = board ;
65
70
m_Name = name ;
66
71
m_Rows = board . Rows ;
67
72
m_Columns = board . Columns ;
68
- m_NumCellTypes = board . NumCellTypes ;
69
- m_NumSpecialTypes = board . NumSpecialTypes ;
73
+ m_GridValues = gvp ;
74
+ m_OneHotSize = oneHotSize ;
70
75
71
76
m_ObservationType = obsType ;
72
77
m_ObservationSpec = obsType == Match3ObservationType . Vector
73
- ? ObservationSpec . Vector ( m_Rows * m_Columns * ( m_NumCellTypes + SpecialTypeSize ) )
74
- : ObservationSpec . Visual ( m_Rows , m_Columns , m_NumCellTypes + SpecialTypeSize ) ;
75
-
76
- // See comment in GetCompressedObservation()
77
- var cellTypePaddedSize = 3 * ( ( m_NumCellTypes + 2 ) / 3 ) ;
78
- m_SparseChannelMapping = new int [ cellTypePaddedSize + SpecialTypeSize ] ;
79
- // If we have 4 cell types and 2 special types (3 special size), we'd have
80
- // [0, 1, 2, 3, -1, -1, 4, 5, 6]
81
- for ( var i = 0 ; i < m_NumCellTypes ; i ++ )
82
- {
83
- m_SparseChannelMapping [ i ] = i ;
84
- }
78
+ ? ObservationSpec . Vector ( m_Rows * m_Columns * oneHotSize )
79
+ : ObservationSpec . Visual ( m_Rows , m_Columns , oneHotSize ) ;
80
+ }
85
81
86
- for ( var i = m_NumCellTypes ; i < cellTypePaddedSize ; i ++ )
87
- {
88
- m_SparseChannelMapping [ i ] = - 1 ;
89
- }
82
+ /// <summary>
83
+ /// Create a sensor that encodes the board cells as observations.
84
+ /// </summary>
85
+ /// <param name="board">The abstract board.</param>
86
+ /// <param name="obsType">Whether to produce vector or visual observations</param>
87
+ /// <param name="name">Name of the sensor.</param>
88
+ /// <returns></returns>
89
+ public static Match3Sensor CellTypeSensor ( AbstractBoard board , Match3ObservationType obsType , string name )
90
+ {
91
+ return new Match3Sensor ( board , board . GetCellType , board . NumCellTypes , obsType , name ) ;
92
+ }
90
93
91
- for ( var i = 0 ; i < SpecialTypeSize ; i ++ )
92
- {
93
- m_SparseChannelMapping [ cellTypePaddedSize + i ] = i + m_NumCellTypes ;
94
- }
94
+ /// <summary>
95
+ /// Create a sensor that encodes the cell special types as observations.
96
+ /// </summary>
97
+ /// <param name="board">The abstract board.</param>
98
+ /// <param name="obsType">Whether to produce vector or visual observations</param>
99
+ /// <param name="name">Name of the sensor.</param>
100
+ /// <returns></returns>
101
+ public static Match3Sensor SpecialTypeSensor ( AbstractBoard board , Match3ObservationType obsType , string name )
102
+ {
103
+ var specialSize = board . NumSpecialTypes == 0 ? 0 : board . NumSpecialTypes + 1 ;
104
+ return new Match3Sensor ( board , board . GetSpecialType , specialSize , obsType , name ) ;
95
105
}
96
106
97
107
/// <inheritdoc/>
@@ -103,14 +113,14 @@ public ObservationSpec GetObservationSpec()
103
113
/// <inheritdoc/>
104
114
public int Write ( ObservationWriter writer )
105
115
{
106
- if ( m_Board . Rows != m_Rows || m_Board . Columns != m_Columns || m_Board . NumCellTypes != m_NumCellTypes )
107
- {
108
- Debug . LogWarning (
109
- $ "Board shape changes since sensor initialization. This may cause unexpected results. " +
110
- $ "Old shape: Rows={ m_Rows } Columns={ m_Columns } , NumCellTypes={ m_NumCellTypes } " +
111
- $ "Current shape: Rows={ m_Board . Rows } Columns={ m_Board . Columns } , NumCellTypes={ m_Board . NumCellTypes } "
112
- ) ;
113
- }
116
+ // if (m_Board.Rows != m_Rows || m_Board.Columns != m_Columns || m_Board.NumCellTypes != m_NumCellTypes)
117
+ // {
118
+ // Debug.LogWarning(
119
+ // $"Board shape changes since sensor initialization. This may cause unexpected results. " +
120
+ // $"Old shape: Rows={m_Rows} Columns={m_Columns}, NumCellTypes={m_NumCellTypes} " +
121
+ // $"Current shape: Rows={m_Board.Rows} Columns={m_Board.Columns}, NumCellTypes={m_Board.NumCellTypes}"
122
+ // );
123
+ // }
114
124
115
125
if ( m_ObservationType == Match3ObservationType . Vector )
116
126
{
@@ -119,22 +129,13 @@ public int Write(ObservationWriter writer)
119
129
{
120
130
for ( var c = 0 ; c < m_Columns ; c ++ )
121
131
{
122
- var val = m_Board . GetCellType ( r , c ) ;
123
- for ( var i = 0 ; i < m_NumCellTypes ; i ++ )
132
+ var val = m_GridValues ( r , c ) ;
133
+
134
+ for ( var i = 0 ; i < m_OneHotSize ; i ++ )
124
135
{
125
136
writer [ offset ] = ( i == val ) ? 1.0f : 0.0f ;
126
137
offset ++ ;
127
138
}
128
-
129
- if ( m_NumSpecialTypes > 0 )
130
- {
131
- var special = m_Board . GetSpecialType ( r , c ) ;
132
- for ( var i = 0 ; i < SpecialTypeSize ; i ++ )
133
- {
134
- writer [ offset ] = ( i == special ) ? 1.0f : 0.0f ;
135
- offset ++ ;
136
- }
137
- }
138
139
}
139
140
}
140
141
@@ -148,22 +149,12 @@ public int Write(ObservationWriter writer)
148
149
{
149
150
for ( var c = 0 ; c < m_Columns ; c ++ )
150
151
{
151
- var val = m_Board . GetCellType ( r , c ) ;
152
- for ( var i = 0 ; i < m_NumCellTypes ; i ++ )
152
+ var val = m_GridValues ( r , c ) ;
153
+ for ( var i = 0 ; i < m_OneHotSize ; i ++ )
153
154
{
154
155
writer [ r , c , i ] = ( i == val ) ? 1.0f : 0.0f ;
155
156
offset ++ ;
156
157
}
157
-
158
- if ( m_NumSpecialTypes > 0 )
159
- {
160
- var special = m_Board . GetSpecialType ( r , c ) ;
161
- for ( var i = 0 ; i < SpecialTypeSize ; i ++ )
162
- {
163
- writer [ offset ] = ( i == special ) ? 1.0f : 0.0f ;
164
- offset ++ ;
165
- }
166
- }
167
158
}
168
159
}
169
160
@@ -185,17 +176,10 @@ public byte[] GetCompressedObservation()
185
176
// fit in in 2 images, but we'll use 3 here (2 PNGs for the 4 cell type channels, and 1 for
186
177
// the special types). Note that we have to also implement the sparse channel mapping.
187
178
// Optimize this it later.
188
- var numCellImages = ( m_NumCellTypes + 2 ) / 3 ;
179
+ var numCellImages = ( m_OneHotSize + 2 ) / 3 ;
189
180
for ( var i = 0 ; i < numCellImages ; i ++ )
190
181
{
191
- converter . EncodeToTexture ( m_Board . GetCellType , tempTexture , 3 * i ) ;
192
- bytesOut . AddRange ( tempTexture . EncodeToPNG ( ) ) ;
193
- }
194
-
195
- var numSpecialImages = ( SpecialTypeSize + 2 ) / 3 ;
196
- for ( var i = 0 ; i < numSpecialImages ; i ++ )
197
- {
198
- converter . EncodeToTexture ( m_Board . GetSpecialType , tempTexture , 3 * i ) ;
182
+ converter . EncodeToTexture ( m_GridValues , tempTexture , 3 * i ) ;
199
183
bytesOut . AddRange ( tempTexture . EncodeToPNG ( ) ) ;
200
184
}
201
185
@@ -223,7 +207,7 @@ internal SensorCompressionType GetCompressionType()
223
207
/// <inheritdoc/>
224
208
public CompressionSpec GetCompressionSpec ( )
225
209
{
226
- return new CompressionSpec ( GetCompressionType ( ) , m_SparseChannelMapping ) ;
210
+ return new CompressionSpec ( GetCompressionType ( ) ) ;
227
211
}
228
212
229
213
/// <inheritdoc/>
@@ -265,9 +249,6 @@ internal class OneHotToTextureUtil
265
249
int m_Width ;
266
250
private static Color [ ] s_OneHotColors = { Color . red , Color . green , Color . blue } ;
267
251
268
- public delegate int GridValueProvider ( int x , int y ) ;
269
-
270
-
271
252
public OneHotToTextureUtil ( int height , int width )
272
253
{
273
254
m_Colors = new Color [ height * width ] ;
0 commit comments