[SPARK-23045][ML][SparkR] Update RFormula to use OneHotEncoderEstimator. #20229

Closed
wants to merge 2 commits into from
1 change: 0 additions & 1 deletion R/pkg/R/mllib_utils.R
Expand Up @@ -130,4 +130,3 @@ read.ml <- function(path) {
stop("Unsupported model: ", jobj)
}
}

20 changes: 15 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
Expand Up @@ -199,6 +199,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
val parsedFormula = RFormulaParser.parse($(formula))
val resolvedFormula = parsedFormula.resolve(dataset.schema)
val encoderStages = ArrayBuffer[PipelineStage]()
val oneHotEncodeColumns = ArrayBuffer[(String, String)]()

val prefixesToRewrite = mutable.Map[String, String]()
val tempColumns = ArrayBuffer[String]()
Expand Down Expand Up @@ -242,16 +243,17 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
val encodedTerms = resolvedFormula.terms.map {
case Seq(term) if dataset.schema(term).dataType == StringType =>
val encodedCol = tmpColumn("onehot")
var encoder = new OneHotEncoder()
.setInputCol(indexed(term))
.setOutputCol(encodedCol)
// Formula w/o intercept, one of the categories in the first category feature is
// being used as reference category, we will not drop any category for that feature.
if (!hasIntercept && !keepReferenceCategory) {
encoder = encoder.setDropLast(false)
encoderStages += new OneHotEncoderEstimator(uid)
.setInputCols(Array(indexed(term)))
.setOutputCols(Array(encodedCol))
.setDropLast(false)
Contributor

This can be optimized: you can merge these multiple (probable) OHEs into one. For example, define:

val oneHotEncodeColumnsNotDropLast = ArrayBuffer[(String, String)]()

and:

if (!hasIntercept && !keepReferenceCategory) {
    oneHotEncodeColumnsNotDropLast += indexed(term) -> encodedCol
} else {
    oneHotEncodeColumns += indexed(term) -> encodedCol
}

Contributor Author

There is at most one encoder with dropLast(false); the next line sets keepReferenceCategory = true to ensure we won't take this code path for the remaining columns.
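
To make that invariant concrete, here is a small standalone sketch of the partitioning between the dropLast(true) and dropLast(false) groups; the term and column names are illustrative only and not taken from the PR:

import scala.collection.mutable.ArrayBuffer

// Sketch (illustrative names only): with no intercept, only the first string
// term keeps its reference category, so at most one (input, output) pair
// ever lands in the dropLast(false) group.
object ReferenceCategorySketch {
  def main(args: Array[String]): Unit = {
    val hasIntercept = false
    var keepReferenceCategory = false
    val stringTerms = Seq("a", "b", "c") // hypothetical string-typed terms

    val dropLastCols = ArrayBuffer[(String, String)]()    // encoded with dropLast(true)
    val noDropLastCols = ArrayBuffer[(String, String)]()  // encoded with dropLast(false)

    stringTerms.foreach { term =>
      val indexedCol = s"${term}_idx"
      val encodedCol = s"${term}_onehot"
      if (!hasIntercept && !keepReferenceCategory) {
        noDropLastCols += indexedCol -> encodedCol // keep all categories for this term
        keepReferenceCategory = true               // remaining terms skip this branch
      } else {
        dropLastCols += indexedCol -> encodedCol   // reference category is dropped
      }
    }

    println(s"dropLast(false) group: $noDropLastCols") // exactly one pair: (a_idx, a_onehot)
    println(s"dropLast(true)  group: $dropLastCols")   // the remaining terms
  }
}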

keepReferenceCategory = true
} else {
oneHotEncodeColumns += indexed(term) -> encodedCol
}
encoderStages += encoder
prefixesToRewrite(encodedCol + "_") = term + "_"
encodedCol
case Seq(term) =>
Expand All @@ -265,6 +267,14 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
interactionCol
}

if (oneHotEncodeColumns.nonEmpty) {
val (inputCols, outputCols) = oneHotEncodeColumns.toArray.unzip
encoderStages += new OneHotEncoderEstimator(uid)
.setInputCols(inputCols)
.setOutputCols(outputCols)
.setDropLast(true)
}

encoderStages += new VectorAssembler(uid)
.setInputCols(encodedTerms.toArray)
.setOutputCol($(featuresCol))
Expand Down
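
For reference, a minimal standalone sketch (not part of this PR; DataFrame and column names are illustrative) of the multi-column OneHotEncoderEstimator API that RFormula now targets, showing why several per-term encoders can collapse into a single pipeline stage:

import org.apache.spark.ml.feature.OneHotEncoderEstimator
import org.apache.spark.sql.SparkSession

object OneHotEncoderEstimatorSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("OneHotEncoderEstimatorSketch")
      .getOrCreate()
    import spark.implicits._

    // Pre-indexed categorical columns, as StringIndexer would produce inside RFormula.
    val df = Seq((0.0, 1.0), (1.0, 0.0), (2.0, 1.0)).toDF("aIndex", "bIndex")

    // A single estimator encodes several columns at once, which is why the
    // per-term OneHotEncoder stages in RFormula can be merged into one stage.
    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("aIndex", "bIndex"))
      .setOutputCols(Array("aVec", "bVec"))
      .setDropLast(true)

    encoder.fit(df).transform(df).show(truncate = false)
    spark.stop()
  }
}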
mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
Expand Up @@ -29,6 +29,17 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {

import testImplicits._

def testRFormulaTransform[A: Encoder](
dataframe: DataFrame,
formulaModel: RFormulaModel,
expected: DataFrame): Unit = {
val (first +: rest) = expected.schema.fieldNames.toSeq
val expectedRows = expected.collect()
testTransformerByGlobalCheckFunc[A](dataframe, formulaModel, first, rest: _*) { rows =>
assert(rows === expectedRows)
}
}

test("params") {
ParamsSuite.checkParams(new RFormula())
}
Expand All @@ -47,7 +58,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
// TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
assert(result.schema.toString == resultSchema.toString)
assert(resultSchema == expected.schema)
assert(result.collect() === expected.collect())
testRFormulaTransform[(Int, Double, Double)](original, model, expected)
}

test("features column already exists") {
Expand Down Expand Up @@ -109,7 +120,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(7, 8.0, 9.0, Vectors.dense(8.0, 9.0))
).toDF("id", "a", "b", "features")
assert(result.schema.toString == resultSchema.toString)
assert(result.collect() === expected.collect())
testRFormulaTransform[(Int, Double, Double)](original, model, expected)
}

test("encodes string terms") {
Expand All @@ -126,7 +137,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)
).toDF("id", "a", "b", "features", "label")
assert(result.schema.toString == resultSchema.toString)
assert(result.collect() === expected.collect())
testRFormulaTransform[(Int, String, Int)](original, model, expected)
}

test("encodes string terms with string indexer order type") {
Expand Down Expand Up @@ -167,7 +178,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
assert(result.schema.toString == resultSchema.toString)
assert(result.collect() === expected(idx).collect())
testRFormulaTransform[(Int, String, Int)](original, model, expected(idx))
idx += 1
}
}
Expand Down Expand Up @@ -210,7 +221,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
assert(result.schema.toString == resultSchema.toString)
assert(result.collect() === expected.collect())
testRFormulaTransform[(Int, String, Int)](original, model, expected)
}

test("formula w/o intercept, we should output reference category when encoding string terms") {
Expand Down Expand Up @@ -253,7 +264,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0)
).toDF("id", "a", "b", "c", "features", "label")
assert(result1.schema.toString == resultSchema1.toString)
assert(result1.collect() === expected1.collect())
testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1)

val attrs1 = AttributeGroup.fromStructField(result1.schema("features"))
val expectedAttrs1 = new AttributeGroup(
Expand All @@ -280,7 +291,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0)
).toDF("id", "a", "b", "c", "features", "label")
assert(result2.schema.toString == resultSchema2.toString)
assert(result2.collect() === expected2.collect())
testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2)

val attrs2 = AttributeGroup.fromStructField(result2.schema("features"))
val expectedAttrs2 = new AttributeGroup(
Expand All @@ -302,15 +313,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5))
.toDF("id", "a", "b")
val model = formula.fit(original)
val result = model.transform(original)
val expected = Seq(
("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0),
("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0)
).toDF("id", "a", "b", "features", "label")
// assert(result.schema.toString == resultSchema.toString)
assert(result.collect() === expected.collect())
testRFormulaTransform[(String, String, Int)](original, model, expected)
}

test("force to index label even it is numeric type") {
Expand All @@ -319,15 +329,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5))
).toDF("id", "a", "b")
val model = formula.fit(original)
val result = model.transform(original)
val expected = spark.createDataFrame(
Seq(
(1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0),
(1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
(0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0),
(1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0))
).toDF("id", "a", "b", "features", "label")
assert(result.collect() === expected.collect())
testRFormulaTransform[(Double, String, Int)](original, model, expected)
}

test("attribute generation") {
Expand Down Expand Up @@ -391,7 +400,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(1, 2, 4, 2, Vectors.dense(16.0), 1.0),
(2, 3, 4, 1, Vectors.dense(12.0), 2.0)
).toDF("a", "b", "c", "d", "features", "label")
assert(result.collect() === expected.collect())
testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected)
val attrs = AttributeGroup.fromStructField(result.schema("features"))
val expectedAttrs = new AttributeGroup(
"features",
Expand All @@ -414,7 +423,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0),
(4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0)
).toDF("id", "a", "b", "features", "label")
assert(result.collect() === expected.collect())
testRFormulaTransform[(Int, String, Int)](original, model, expected)
val attrs = AttributeGroup.fromStructField(result.schema("features"))
val expectedAttrs = new AttributeGroup(
"features",
Expand All @@ -436,7 +445,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0),
(3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0)
).toDF("id", "a", "b", "features", "label")
assert(result.collect() === expected.collect())
testRFormulaTransform[(Int, String, String)](original, model, expected)
val attrs = AttributeGroup.fromStructField(result.schema("features"))
val expectedAttrs = new AttributeGroup(
"features",
Expand Down Expand Up @@ -511,8 +520,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
intercept[SparkException] {
formula1.fit(df1).transform(df2).collect()
}
val result1 = formula1.setHandleInvalid("skip").fit(df1).transform(df2)
val result2 = formula1.setHandleInvalid("keep").fit(df1).transform(df2)
val model1 = formula1.setHandleInvalid("skip").fit(df1)
val model2 = formula1.setHandleInvalid("keep").fit(df1)

val expected1 = Seq(
(1, "foo", "zq", Vectors.dense(0.0, 1.0), 1.0),
Expand All @@ -524,16 +533,16 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(3, "bar", "zy", Vectors.dense(1.0, 0.0, 0.0, 0.0), 3.0)
).toDF("id", "a", "b", "features", "label")

assert(result1.collect() === expected1.collect())
assert(result2.collect() === expected2.collect())
testRFormulaTransform[(Int, String, String)](df2, model1, expected1)
testRFormulaTransform[(Int, String, String)](df2, model2, expected2)

// Handle unseen labels.
val formula2 = new RFormula().setFormula("b ~ a + id")
intercept[SparkException] {
formula2.fit(df1).transform(df2).collect()
}
val result3 = formula2.setHandleInvalid("skip").fit(df1).transform(df2)
val result4 = formula2.setHandleInvalid("keep").fit(df1).transform(df2)
val model3 = formula2.setHandleInvalid("skip").fit(df1)
val model4 = formula2.setHandleInvalid("keep").fit(df1)

val expected3 = Seq(
(1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0),
Expand All @@ -545,8 +554,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0)
).toDF("id", "a", "b", "features", "label")

assert(result3.collect() === expected3.collect())
assert(result4.collect() === expected4.collect())
testRFormulaTransform[(Int, String, String)](df2, model3, expected3)
testRFormulaTransform[(Int, String, String)](df2, model4, expected4)
}

test("Use Vectors as inputs to formula.") {
Expand Down