Skip to content

Commit 4371466

Browse files
MrBago authored and jkbradley committed
[SPARK-23045][ML][SPARKR] Update RFormula to use OneHotEncoderEstimator.
## What changes were proposed in this pull request?

RFormula should use VectorSizeHint & OneHotEncoderEstimator in its pipeline to avoid using the deprecated OneHotEncoder & to ensure the model produced can be used in streaming.

## How was this patch tested?

Unit tests.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Bago Amirbekian <bago@databricks.com>

Closes #20229 from MrBago/rFormula.
1 parent 12db365 commit 4371466

File tree

3 files changed

+46
-28
lines changed

3 files changed

+46
-28
lines changed

R/pkg/R/mllib_utils.R

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -130,4 +130,3 @@ read.ml <- function(path) {
130130
stop("Unsupported model: ", jobj)
131131
}
132132
}
133-

mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -199,6 +199,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
199199
val parsedFormula = RFormulaParser.parse($(formula))
200200
val resolvedFormula = parsedFormula.resolve(dataset.schema)
201201
val encoderStages = ArrayBuffer[PipelineStage]()
202+
val oneHotEncodeColumns = ArrayBuffer[(String, String)]()
202203

203204
val prefixesToRewrite = mutable.Map[String, String]()
204205
val tempColumns = ArrayBuffer[String]()
@@ -242,16 +243,17 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
242243
val encodedTerms = resolvedFormula.terms.map {
243244
case Seq(term) if dataset.schema(term).dataType == StringType =>
244245
val encodedCol = tmpColumn("onehot")
245-
var encoder = new OneHotEncoder()
246-
.setInputCol(indexed(term))
247-
.setOutputCol(encodedCol)
248246
// Formula w/o intercept, one of the categories in the first category feature is
249247
// being used as reference category, we will not drop any category for that feature.
250248
if (!hasIntercept && !keepReferenceCategory) {
251-
encoder = encoder.setDropLast(false)
249+
encoderStages += new OneHotEncoderEstimator(uid)
250+
.setInputCols(Array(indexed(term)))
251+
.setOutputCols(Array(encodedCol))
252+
.setDropLast(false)
252253
keepReferenceCategory = true
254+
} else {
255+
oneHotEncodeColumns += indexed(term) -> encodedCol
253256
}
254-
encoderStages += encoder
255257
prefixesToRewrite(encodedCol + "_") = term + "_"
256258
encodedCol
257259
case Seq(term) =>
@@ -265,6 +267,14 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
265267
interactionCol
266268
}
267269

270+
if (oneHotEncodeColumns.nonEmpty) {
271+
val (inputCols, outputCols) = oneHotEncodeColumns.toArray.unzip
272+
encoderStages += new OneHotEncoderEstimator(uid)
273+
.setInputCols(inputCols)
274+
.setOutputCols(outputCols)
275+
.setDropLast(true)
276+
}
277+
268278
encoderStages += new VectorAssembler(uid)
269279
.setInputCols(encodedTerms.toArray)
270280
.setOutputCol($(featuresCol))

mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala

Lines changed: 31 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,17 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
2929

3030
import testImplicits._
3131

32+
/**
 * Test helper: checks that `formulaModel` applied to `dataframe` yields the rows of `expected`.
 *
 * The column names of `expected` select which output columns are compared; the comparison
 * itself is delegated to `testTransformerByGlobalCheckFunc`.
 * NOTE(review): presumably that helper runs the transformer in both batch and streaming
 * mode -- confirm against MLTest.
 *
 * @tparam A type with an implicit `Encoder`, forwarded to `testTransformerByGlobalCheckFunc`;
 *           call sites pass a tuple type matching the input columns of `dataframe`
 * @param dataframe input data handed to the transformer check
 * @param formulaModel fitted model under test
 * @param expected DataFrame whose schema names the columns to compare and whose rows are the
 *                 expected output; the pattern below requires at least one column
 */
def testRFormulaTransform[A: Encoder](
33+
dataframe: DataFrame,
34+
formulaModel: RFormulaModel,
35+
expected: DataFrame): Unit = {
36+
// Split the expected column names into head and tail, because the check helper takes
// them as (first, rest: _*).
val (first +: rest) = expected.schema.fieldNames.toSeq
37+
// Collect once up front so the closure below compares against a plain array of rows.
val expectedRows = expected.collect()
38+
testTransformerByGlobalCheckFunc[A](dataframe, formulaModel, first, rest: _*) { rows =>
39+
// Compare all rows at once against the expected rows.
assert(rows === expectedRows)
40+
}
41+
}
42+
3243
// Runs ParamsSuite.checkParams against a default-constructed RFormula.
test("params") {
3344
ParamsSuite.checkParams(new RFormula())
3445
}
@@ -47,7 +58,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
4758
// TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
4859
assert(result.schema.toString == resultSchema.toString)
4960
assert(resultSchema == expected.schema)
50-
assert(result.collect() === expected.collect())
61+
testRFormulaTransform[(Int, Double, Double)](original, model, expected)
5162
}
5263

5364
test("features column already exists") {
@@ -109,7 +120,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
109120
(7, 8.0, 9.0, Vectors.dense(8.0, 9.0))
110121
).toDF("id", "a", "b", "features")
111122
assert(result.schema.toString == resultSchema.toString)
112-
assert(result.collect() === expected.collect())
123+
testRFormulaTransform[(Int, Double, Double)](original, model, expected)
113124
}
114125

115126
test("encodes string terms") {
@@ -126,7 +137,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
126137
(4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)
127138
).toDF("id", "a", "b", "features", "label")
128139
assert(result.schema.toString == resultSchema.toString)
129-
assert(result.collect() === expected.collect())
140+
testRFormulaTransform[(Int, String, Int)](original, model, expected)
130141
}
131142

132143
test("encodes string terms with string indexer order type") {
@@ -167,7 +178,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
167178
val result = model.transform(original)
168179
val resultSchema = model.transformSchema(original.schema)
169180
assert(result.schema.toString == resultSchema.toString)
170-
assert(result.collect() === expected(idx).collect())
181+
testRFormulaTransform[(Int, String, Int)](original, model, expected(idx))
171182
idx += 1
172183
}
173184
}
@@ -210,7 +221,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
210221
val result = model.transform(original)
211222
val resultSchema = model.transformSchema(original.schema)
212223
assert(result.schema.toString == resultSchema.toString)
213-
assert(result.collect() === expected.collect())
224+
testRFormulaTransform[(Int, String, Int)](original, model, expected)
214225
}
215226

216227
test("formula w/o intercept, we should output reference category when encoding string terms") {
@@ -253,7 +264,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
253264
(4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0)
254265
).toDF("id", "a", "b", "c", "features", "label")
255266
assert(result1.schema.toString == resultSchema1.toString)
256-
assert(result1.collect() === expected1.collect())
267+
testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1)
257268

258269
val attrs1 = AttributeGroup.fromStructField(result1.schema("features"))
259270
val expectedAttrs1 = new AttributeGroup(
@@ -280,7 +291,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
280291
(4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0)
281292
).toDF("id", "a", "b", "c", "features", "label")
282293
assert(result2.schema.toString == resultSchema2.toString)
283-
assert(result2.collect() === expected2.collect())
294+
testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2)
284295

285296
val attrs2 = AttributeGroup.fromStructField(result2.schema("features"))
286297
val expectedAttrs2 = new AttributeGroup(
@@ -302,15 +313,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
302313
Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5))
303314
.toDF("id", "a", "b")
304315
val model = formula.fit(original)
305-
val result = model.transform(original)
306316
val expected = Seq(
307317
("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
308318
("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
309319
("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0),
310320
("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0)
311321
).toDF("id", "a", "b", "features", "label")
312322
// assert(result.schema.toString == resultSchema.toString)
313-
assert(result.collect() === expected.collect())
323+
testRFormulaTransform[(String, String, Int)](original, model, expected)
314324
}
315325

316326
test("force to index label even it is numeric type") {
@@ -319,15 +329,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
319329
Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5))
320330
).toDF("id", "a", "b")
321331
val model = formula.fit(original)
322-
val result = model.transform(original)
323332
val expected = spark.createDataFrame(
324333
Seq(
325334
(1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0),
326335
(1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
327336
(0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0),
328337
(1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0))
329338
).toDF("id", "a", "b", "features", "label")
330-
assert(result.collect() === expected.collect())
339+
testRFormulaTransform[(Double, String, Int)](original, model, expected)
331340
}
332341

333342
test("attribute generation") {
@@ -391,7 +400,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
391400
(1, 2, 4, 2, Vectors.dense(16.0), 1.0),
392401
(2, 3, 4, 1, Vectors.dense(12.0), 2.0)
393402
).toDF("a", "b", "c", "d", "features", "label")
394-
assert(result.collect() === expected.collect())
403+
testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected)
395404
val attrs = AttributeGroup.fromStructField(result.schema("features"))
396405
val expectedAttrs = new AttributeGroup(
397406
"features",
@@ -414,7 +423,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
414423
(4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0),
415424
(4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0)
416425
).toDF("id", "a", "b", "features", "label")
417-
assert(result.collect() === expected.collect())
426+
testRFormulaTransform[(Int, String, Int)](original, model, expected)
418427
val attrs = AttributeGroup.fromStructField(result.schema("features"))
419428
val expectedAttrs = new AttributeGroup(
420429
"features",
@@ -436,7 +445,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
436445
(2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0),
437446
(3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0)
438447
).toDF("id", "a", "b", "features", "label")
439-
assert(result.collect() === expected.collect())
448+
testRFormulaTransform[(Int, String, String)](original, model, expected)
440449
val attrs = AttributeGroup.fromStructField(result.schema("features"))
441450
val expectedAttrs = new AttributeGroup(
442451
"features",
@@ -511,8 +520,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
511520
intercept[SparkException] {
512521
formula1.fit(df1).transform(df2).collect()
513522
}
514-
val result1 = formula1.setHandleInvalid("skip").fit(df1).transform(df2)
515-
val result2 = formula1.setHandleInvalid("keep").fit(df1).transform(df2)
523+
val model1 = formula1.setHandleInvalid("skip").fit(df1)
524+
val model2 = formula1.setHandleInvalid("keep").fit(df1)
516525

517526
val expected1 = Seq(
518527
(1, "foo", "zq", Vectors.dense(0.0, 1.0), 1.0),
@@ -524,16 +533,16 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
524533
(3, "bar", "zy", Vectors.dense(1.0, 0.0, 0.0, 0.0), 3.0)
525534
).toDF("id", "a", "b", "features", "label")
526535

527-
assert(result1.collect() === expected1.collect())
528-
assert(result2.collect() === expected2.collect())
536+
testRFormulaTransform[(Int, String, String)](df2, model1, expected1)
537+
testRFormulaTransform[(Int, String, String)](df2, model2, expected2)
529538

530539
// Handle unseen labels.
531540
val formula2 = new RFormula().setFormula("b ~ a + id")
532541
intercept[SparkException] {
533542
formula2.fit(df1).transform(df2).collect()
534543
}
535-
val result3 = formula2.setHandleInvalid("skip").fit(df1).transform(df2)
536-
val result4 = formula2.setHandleInvalid("keep").fit(df1).transform(df2)
544+
val model3 = formula2.setHandleInvalid("skip").fit(df1)
545+
val model4 = formula2.setHandleInvalid("keep").fit(df1)
537546

538547
val expected3 = Seq(
539548
(1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0),
@@ -545,8 +554,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
545554
(3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0)
546555
).toDF("id", "a", "b", "features", "label")
547556

548-
assert(result3.collect() === expected3.collect())
549-
assert(result4.collect() === expected4.collect())
557+
testRFormulaTransform[(Int, String, String)](df2, model3, expected3)
558+
testRFormulaTransform[(Int, String, String)](df2, model4, expected4)
550559
}
551560

552561
test("Use Vectors as inputs to formula.") {

0 commit comments

Comments
 (0)