@@ -29,6 +29,17 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
29
29
30
30
import testImplicits ._
31
31
32
+ def testRFormulaTransform [A : Encoder ](
33
+ dataframe : DataFrame ,
34
+ formulaModel : RFormulaModel ,
35
+ expected : DataFrame ): Unit = {
36
+ val (first +: rest) = expected.schema.fieldNames.toSeq
37
+ val expectedRows = expected.collect()
38
+ testTransformerByGlobalCheckFunc[A ](dataframe, formulaModel, first, rest : _* ) { rows =>
39
+ assert(rows === expectedRows)
40
+ }
41
+ }
42
+
32
43
test(" params" ) {
33
44
ParamsSuite .checkParams(new RFormula ())
34
45
}
@@ -47,7 +58,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
47
58
// TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
48
59
assert(result.schema.toString == resultSchema.toString)
49
60
assert(resultSchema == expected.schema)
50
- assert(result.collect() === expected.collect() )
61
+ testRFormulaTransform[( Int , Double , Double )](original, model, expected)
51
62
}
52
63
53
64
test(" features column already exists" ) {
@@ -109,7 +120,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
109
120
(7 , 8.0 , 9.0 , Vectors .dense(8.0 , 9.0 ))
110
121
).toDF(" id" , " a" , " b" , " features" )
111
122
assert(result.schema.toString == resultSchema.toString)
112
- assert(result.collect() === expected.collect() )
123
+ testRFormulaTransform[( Int , Double , Double )](original, model, expected)
113
124
}
114
125
115
126
test(" encodes string terms" ) {
@@ -126,7 +137,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
126
137
(4 , " baz" , 5 , Vectors .dense(0.0 , 0.0 , 5.0 ), 4.0 )
127
138
).toDF(" id" , " a" , " b" , " features" , " label" )
128
139
assert(result.schema.toString == resultSchema.toString)
129
- assert(result.collect() === expected.collect() )
140
+ testRFormulaTransform[( Int , String , Int )](original, model, expected)
130
141
}
131
142
132
143
test(" encodes string terms with string indexer order type" ) {
@@ -167,7 +178,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
167
178
val result = model.transform(original)
168
179
val resultSchema = model.transformSchema(original.schema)
169
180
assert(result.schema.toString == resultSchema.toString)
170
- assert(result.collect() === expected(idx).collect( ))
181
+ testRFormulaTransform[( Int , String , Int )](original, model, expected(idx))
171
182
idx += 1
172
183
}
173
184
}
@@ -210,7 +221,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
210
221
val result = model.transform(original)
211
222
val resultSchema = model.transformSchema(original.schema)
212
223
assert(result.schema.toString == resultSchema.toString)
213
- assert(result.collect() === expected.collect() )
224
+ testRFormulaTransform[( Int , String , Int )](original, model, expected)
214
225
}
215
226
216
227
test(" formula w/o intercept, we should output reference category when encoding string terms" ) {
@@ -253,7 +264,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
253
264
(4 , " baz" , " zz" , 5 , Vectors .dense(0.0 , 1.0 , 0.0 , 1.0 , 5.0 ), 4.0 )
254
265
).toDF(" id" , " a" , " b" , " c" , " features" , " label" )
255
266
assert(result1.schema.toString == resultSchema1.toString)
256
- assert(result1.collect() === expected1.collect() )
267
+ testRFormulaTransform[( Int , String , String , Int )](original, model1, expected1)
257
268
258
269
val attrs1 = AttributeGroup .fromStructField(result1.schema(" features" ))
259
270
val expectedAttrs1 = new AttributeGroup (
@@ -280,7 +291,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
280
291
(4 , " baz" , " zz" , 5 , Vectors .sparse(7 , Array (2 , 6 ), Array (1.0 , 5.0 )), 4.0 )
281
292
).toDF(" id" , " a" , " b" , " c" , " features" , " label" )
282
293
assert(result2.schema.toString == resultSchema2.toString)
283
- assert(result2.collect() === expected2.collect() )
294
+ testRFormulaTransform[( Int , String , String , Int )](original, model2, expected2)
284
295
285
296
val attrs2 = AttributeGroup .fromStructField(result2.schema(" features" ))
286
297
val expectedAttrs2 = new AttributeGroup (
@@ -302,15 +313,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
302
313
Seq ((" male" , " foo" , 4 ), (" female" , " bar" , 4 ), (" female" , " bar" , 5 ), (" male" , " baz" , 5 ))
303
314
.toDF(" id" , " a" , " b" )
304
315
val model = formula.fit(original)
305
- val result = model.transform(original)
306
316
val expected = Seq (
307
317
(" male" , " foo" , 4 , Vectors .dense(0.0 , 1.0 , 4.0 ), 1.0 ),
308
318
(" female" , " bar" , 4 , Vectors .dense(1.0 , 0.0 , 4.0 ), 0.0 ),
309
319
(" female" , " bar" , 5 , Vectors .dense(1.0 , 0.0 , 5.0 ), 0.0 ),
310
320
(" male" , " baz" , 5 , Vectors .dense(0.0 , 0.0 , 5.0 ), 1.0 )
311
321
).toDF(" id" , " a" , " b" , " features" , " label" )
312
322
// assert(result.schema.toString == resultSchema.toString)
313
- assert(result.collect() === expected.collect() )
323
+ testRFormulaTransform[( String , String , Int )](original, model, expected)
314
324
}
315
325
316
326
test(" force to index label even it is numeric type" ) {
@@ -319,15 +329,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
319
329
Seq ((1.0 , " foo" , 4 ), (1.0 , " bar" , 4 ), (0.0 , " bar" , 5 ), (1.0 , " baz" , 5 ))
320
330
).toDF(" id" , " a" , " b" )
321
331
val model = formula.fit(original)
322
- val result = model.transform(original)
323
332
val expected = spark.createDataFrame(
324
333
Seq (
325
334
(1.0 , " foo" , 4 , Vectors .dense(0.0 , 1.0 , 4.0 ), 0.0 ),
326
335
(1.0 , " bar" , 4 , Vectors .dense(1.0 , 0.0 , 4.0 ), 0.0 ),
327
336
(0.0 , " bar" , 5 , Vectors .dense(1.0 , 0.0 , 5.0 ), 1.0 ),
328
337
(1.0 , " baz" , 5 , Vectors .dense(0.0 , 0.0 , 5.0 ), 0.0 ))
329
338
).toDF(" id" , " a" , " b" , " features" , " label" )
330
- assert(result.collect() === expected.collect() )
339
+ testRFormulaTransform[( Double , String , Int )](original, model, expected)
331
340
}
332
341
333
342
test(" attribute generation" ) {
@@ -391,7 +400,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
391
400
(1 , 2 , 4 , 2 , Vectors .dense(16.0 ), 1.0 ),
392
401
(2 , 3 , 4 , 1 , Vectors .dense(12.0 ), 2.0 )
393
402
).toDF(" a" , " b" , " c" , " d" , " features" , " label" )
394
- assert(result.collect() === expected.collect() )
403
+ testRFormulaTransform[( Int , Int , Int , Int )](original, model, expected)
395
404
val attrs = AttributeGroup .fromStructField(result.schema(" features" ))
396
405
val expectedAttrs = new AttributeGroup (
397
406
" features" ,
@@ -414,7 +423,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
414
423
(4 , " baz" , 5 , Vectors .dense(5.0 , 0.0 , 0.0 ), 4.0 ),
415
424
(4 , " baz" , 5 , Vectors .dense(5.0 , 0.0 , 0.0 ), 4.0 )
416
425
).toDF(" id" , " a" , " b" , " features" , " label" )
417
- assert(result.collect() === expected.collect() )
426
+ testRFormulaTransform[( Int , String , Int )](original, model, expected)
418
427
val attrs = AttributeGroup .fromStructField(result.schema(" features" ))
419
428
val expectedAttrs = new AttributeGroup (
420
429
" features" ,
@@ -436,7 +445,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
436
445
(2 , " bar" , " zq" , Vectors .dense(1.0 , 0.0 , 0.0 , 0.0 ), 2.0 ),
437
446
(3 , " bar" , " zz" , Vectors .dense(0.0 , 1.0 , 0.0 , 0.0 ), 3.0 )
438
447
).toDF(" id" , " a" , " b" , " features" , " label" )
439
- assert(result.collect() === expected.collect() )
448
+ testRFormulaTransform[( Int , String , String )](original, model, expected)
440
449
val attrs = AttributeGroup .fromStructField(result.schema(" features" ))
441
450
val expectedAttrs = new AttributeGroup (
442
451
" features" ,
@@ -511,8 +520,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
511
520
intercept[SparkException ] {
512
521
formula1.fit(df1).transform(df2).collect()
513
522
}
514
- val result1 = formula1.setHandleInvalid(" skip" ).fit(df1).transform(df2 )
515
- val result2 = formula1.setHandleInvalid(" keep" ).fit(df1).transform(df2 )
523
+ val model1 = formula1.setHandleInvalid(" skip" ).fit(df1)
524
+ val model2 = formula1.setHandleInvalid(" keep" ).fit(df1)
516
525
517
526
val expected1 = Seq (
518
527
(1 , " foo" , " zq" , Vectors .dense(0.0 , 1.0 ), 1.0 ),
@@ -524,16 +533,16 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
524
533
(3 , " bar" , " zy" , Vectors .dense(1.0 , 0.0 , 0.0 , 0.0 ), 3.0 )
525
534
).toDF(" id" , " a" , " b" , " features" , " label" )
526
535
527
- assert(result1.collect() === expected1.collect() )
528
- assert(result2.collect() === expected2.collect() )
536
+ testRFormulaTransform[( Int , String , String )](df2, model1, expected1)
537
+ testRFormulaTransform[( Int , String , String )](df2, model2, expected2)
529
538
530
539
// Handle unseen labels.
531
540
val formula2 = new RFormula ().setFormula(" b ~ a + id" )
532
541
intercept[SparkException ] {
533
542
formula2.fit(df1).transform(df2).collect()
534
543
}
535
- val result3 = formula2.setHandleInvalid(" skip" ).fit(df1).transform(df2 )
536
- val result4 = formula2.setHandleInvalid(" keep" ).fit(df1).transform(df2 )
544
+ val model3 = formula2.setHandleInvalid(" skip" ).fit(df1)
545
+ val model4 = formula2.setHandleInvalid(" keep" ).fit(df1)
537
546
538
547
val expected3 = Seq (
539
548
(1 , " foo" , " zq" , Vectors .dense(0.0 , 1.0 ), 0.0 ),
@@ -545,8 +554,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
545
554
(3 , " bar" , " zy" , Vectors .dense(1.0 , 0.0 , 3.0 ), 2.0 )
546
555
).toDF(" id" , " a" , " b" , " features" , " label" )
547
556
548
- assert(result3.collect() === expected3.collect() )
549
- assert(result4.collect() === expected4.collect() )
557
+ testRFormulaTransform[( Int , String , String )](df2, model3, expected3)
558
+ testRFormulaTransform[( Int , String , String )](df2, model4, expected4)
550
559
}
551
560
552
561
test(" Use Vectors as inputs to formula." ) {
0 commit comments