@@ -129,6 +129,90 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
129
129
assert(result.collect() === expected.collect())
130
130
}
131
131
132
+ test(" encodes string terms with string indexer order type" ) {
133
+ val formula = new RFormula ().setFormula(" id ~ a + b" )
134
+ val original = Seq ((1 , " foo" , 4 ), (2 , " bar" , 4 ), (3 , " bar" , 5 ), (4 , " aaz" , 5 ))
135
+ .toDF(" id" , " a" , " b" )
136
+
137
+ val expected = Seq (
138
+ Seq (
139
+ (1 , " foo" , 4 , Vectors .dense(0.0 , 0.0 , 4.0 ), 1.0 ),
140
+ (2 , " bar" , 4 , Vectors .dense(1.0 , 0.0 , 4.0 ), 2.0 ),
141
+ (3 , " bar" , 5 , Vectors .dense(1.0 , 0.0 , 5.0 ), 3.0 ),
142
+ (4 , " aaz" , 5 , Vectors .dense(0.0 , 1.0 , 5.0 ), 4.0 )
143
+ ).toDF(" id" , " a" , " b" , " features" , " label" ),
144
+ Seq (
145
+ (1 , " foo" , 4 , Vectors .dense(0.0 , 1.0 , 4.0 ), 1.0 ),
146
+ (2 , " bar" , 4 , Vectors .dense(0.0 , 0.0 , 4.0 ), 2.0 ),
147
+ (3 , " bar" , 5 , Vectors .dense(0.0 , 0.0 , 5.0 ), 3.0 ),
148
+ (4 , " aaz" , 5 , Vectors .dense(1.0 , 0.0 , 5.0 ), 4.0 )
149
+ ).toDF(" id" , " a" , " b" , " features" , " label" ),
150
+ Seq (
151
+ (1 , " foo" , 4 , Vectors .dense(1.0 , 0.0 , 4.0 ), 1.0 ),
152
+ (2 , " bar" , 4 , Vectors .dense(0.0 , 1.0 , 4.0 ), 2.0 ),
153
+ (3 , " bar" , 5 , Vectors .dense(0.0 , 1.0 , 5.0 ), 3.0 ),
154
+ (4 , " aaz" , 5 , Vectors .dense(0.0 , 0.0 , 5.0 ), 4.0 )
155
+ ).toDF(" id" , " a" , " b" , " features" , " label" ),
156
+ Seq (
157
+ (1 , " foo" , 4 , Vectors .dense(0.0 , 0.0 , 4.0 ), 1.0 ),
158
+ (2 , " bar" , 4 , Vectors .dense(0.0 , 1.0 , 4.0 ), 2.0 ),
159
+ (3 , " bar" , 5 , Vectors .dense(0.0 , 1.0 , 5.0 ), 3.0 ),
160
+ (4 , " aaz" , 5 , Vectors .dense(1.0 , 0.0 , 5.0 ), 4.0 )
161
+ ).toDF(" id" , " a" , " b" , " features" , " label" )
162
+ )
163
+
164
+ var idx = 0
165
+ for (orderType <- StringIndexer .supportedStringOrderType) {
166
+ val model = formula.setStringIndexerOrderType(orderType).fit(original)
167
+ val result = model.transform(original)
168
+ val resultSchema = model.transformSchema(original.schema)
169
+ assert(result.schema.toString == resultSchema.toString)
170
+ assert(result.collect() === expected(idx).collect())
171
+ idx += 1
172
+ }
173
+ }
174
+
175
+ test(" test consistency with R when encoding string terms" ) {
176
+ /*
177
+ R code:
178
+
179
+ df <- data.frame(id = c(1, 2, 3, 4),
180
+ a = c("foo", "bar", "bar", "aaz"),
181
+ b = c(4, 4, 5, 5))
182
+ model.matrix(id ~ a + b, df)[, -1]
183
+
184
+ abar afoo b
185
+ 0 1 4
186
+ 1 0 4
187
+ 1 0 5
188
+ 0 0 5
189
+ */
190
+ val original = Seq ((1 , " foo" , 4 ), (2 , " bar" , 4 ), (3 , " bar" , 5 ), (4 , " aaz" , 5 ))
191
+ .toDF(" id" , " a" , " b" )
192
+ val formula = new RFormula ().setFormula(" id ~ a + b" )
193
+ .setStringIndexerOrderType(StringIndexer .alphabetDesc)
194
+
195
+ /*
196
+ Note that the category dropped after encoding is the same between R and Spark
197
+ (i.e., "aaz" is treated as the reference level).
198
+ However, the column order is still different:
199
+ R renders the columns in ascending alphabetical order ("bar", "foo"), while
200
+ RFormula renders the columns in descending alphabetical order ("foo", "bar").
201
+ */
202
+ val expected = Seq (
203
+ (1 , " foo" , 4 , Vectors .dense(1.0 , 0.0 , 4.0 ), 1.0 ),
204
+ (2 , " bar" , 4 , Vectors .dense(0.0 , 1.0 , 4.0 ), 2.0 ),
205
+ (3 , " bar" , 5 , Vectors .dense(0.0 , 1.0 , 5.0 ), 3.0 ),
206
+ (4 , " aaz" , 5 , Vectors .dense(0.0 , 0.0 , 5.0 ), 4.0 )
207
+ ).toDF(" id" , " a" , " b" , " features" , " label" )
208
+
209
+ val model = formula.fit(original)
210
+ val result = model.transform(original)
211
+ val resultSchema = model.transformSchema(original.schema)
212
+ assert(result.schema.toString == resultSchema.toString)
213
+ assert(result.collect() === expected.collect())
214
+ }
215
+
132
216
test(" index string label" ) {
133
217
val formula = new RFormula ().setFormula(" id ~ a + b" )
134
218
val original =
0 commit comments