Skip to content

Commit f47700c

Browse files
Wayne Zhangyanboliang
authored andcommitted
[SPARK-14659][ML] RFormula consistent with R when handling strings
## What changes were proposed in this pull request? When handling strings, the category dropped by RFormula and R are different: - RFormula drops the least frequent level - R drops the first level after ascending alphabetical ordering This PR supports different string ordering types in StringIndexer #17879 so that RFormula can drop the same level as R when handling strings using`stringOrderType = "alphabetDesc"`. ## How was this patch tested? new tests Author: Wayne Zhang <actuaryzhang@uber.com> Closes #17967 from actuaryzhang/RFormula.
1 parent 2dbe0c5 commit f47700c

File tree

3 files changed

+129
-3
lines changed

3 files changed

+129
-3
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.spark.annotation.{Experimental, Since}
2626
import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
2727
import org.apache.spark.ml.attribute.AttributeGroup
2828
import org.apache.spark.ml.linalg.VectorUDT
29-
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap}
29+
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
3030
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
3131
import org.apache.spark.ml.util._
3232
import org.apache.spark.sql.{DataFrame, Dataset}
@@ -37,6 +37,42 @@ import org.apache.spark.sql.types._
3737
*/
3838
private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
3939

40+
/**
41+
* Param for how to order categories of a string FEATURE column used by `StringIndexer`.
42+
* The last category after ordering is dropped when encoding strings.
43+
* Supported options: 'frequencyDesc', 'frequencyAsc', 'alphabetDesc', 'alphabetAsc'.
44+
* The default value is 'frequencyDesc'. When the ordering is set to 'alphabetDesc', `RFormula`
45+
* drops the same category as R when encoding strings.
46+
*
47+
* The options are explained using an example `'b', 'a', 'b', 'a', 'c', 'b'`:
48+
* {{{
49+
* +-----------------+---------------------------------------+----------------------------------+
50+
* | Option | Category mapped to 0 by StringIndexer | Category dropped by RFormula |
51+
* +-----------------+---------------------------------------+----------------------------------+
52+
* | 'frequencyDesc' | most frequent category ('b') | least frequent category ('c') |
53+
* | 'frequencyAsc' | least frequent category ('c') | most frequent category ('b') |
54+
* | 'alphabetDesc' | last alphabetical category ('c') | first alphabetical category ('a')|
55+
* | 'alphabetAsc' | first alphabetical category ('a') | last alphabetical category ('c') |
56+
* +-----------------+---------------------------------------+----------------------------------+
57+
* }}}
58+
* Note that this ordering option is NOT used for the label column. When the label column is
59+
* indexed, it uses the default descending frequency ordering in `StringIndexer`.
60+
*
61+
* @group param
62+
*/
63+
@Since("2.3.0")
64+
final val stringIndexerOrderType: Param[String] = new Param(this, "stringIndexerOrderType",
65+
"How to order categories of a string FEATURE column used by StringIndexer. " +
66+
"The last category after ordering is dropped when encoding strings. " +
67+
s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}. " +
68+
"The default value is 'frequencyDesc'. When the ordering is set to 'alphabetDesc', " +
69+
"RFormula drops the same category as R when encoding strings.",
70+
ParamValidators.inArray(StringIndexer.supportedStringOrderType))
71+
72+
/** @group getParam */
73+
@Since("2.3.0")
74+
def getStringIndexerOrderType: String = $(stringIndexerOrderType)
75+
4076
protected def hasLabelCol(schema: StructType): Boolean = {
4177
schema.map(_.name).contains($(labelCol))
4278
}
@@ -125,6 +161,11 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
125161
@Since("2.1.0")
126162
def setForceIndexLabel(value: Boolean): this.type = set(forceIndexLabel, value)
127163

164+
/** @group setParam */
165+
@Since("2.3.0")
166+
def setStringIndexerOrderType(value: String): this.type = set(stringIndexerOrderType, value)
167+
setDefault(stringIndexerOrderType, StringIndexer.frequencyDesc)
168+
128169
/** Whether the formula specifies fitting an intercept. */
129170
private[ml] def hasIntercept: Boolean = {
130171
require(isDefined(formula), "Formula must be defined first.")
@@ -155,6 +196,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
155196
encoderStages += new StringIndexer()
156197
.setInputCol(term)
157198
.setOutputCol(indexCol)
199+
.setStringOrderType($(stringIndexerOrderType))
158200
prefixesToRewrite(indexCol + "_") = term + "_"
159201
(term, indexCol)
160202
case _ =>

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
4747
* @group param
4848
*/
4949
@Since("1.6.0")
50-
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
50+
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle " +
5151
"invalid data (unseen labels or NULL values). " +
5252
"Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
5353
"or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
@@ -73,7 +73,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
7373
*/
7474
@Since("2.3.0")
7575
final val stringOrderType: Param[String] = new Param(this, "stringOrderType",
76-
"how to order labels of string column. " +
76+
"How to order labels of string column. " +
7777
"The first label after ordering is assigned an index of 0. " +
7878
s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.",
7979
ParamValidators.inArray(StringIndexer.supportedStringOrderType))

mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,90 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
129129
assert(result.collect() === expected.collect())
130130
}
131131

132+
test("encodes string terms with string indexer order type") {
133+
val formula = new RFormula().setFormula("id ~ a + b")
134+
val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "aaz", 5))
135+
.toDF("id", "a", "b")
136+
137+
val expected = Seq(
138+
Seq(
139+
(1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0),
140+
(2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
141+
(3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0),
142+
(4, "aaz", 5, Vectors.dense(0.0, 1.0, 5.0), 4.0)
143+
).toDF("id", "a", "b", "features", "label"),
144+
Seq(
145+
(1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
146+
(2, "bar", 4, Vectors.dense(0.0, 0.0, 4.0), 2.0),
147+
(3, "bar", 5, Vectors.dense(0.0, 0.0, 5.0), 3.0),
148+
(4, "aaz", 5, Vectors.dense(1.0, 0.0, 5.0), 4.0)
149+
).toDF("id", "a", "b", "features", "label"),
150+
Seq(
151+
(1, "foo", 4, Vectors.dense(1.0, 0.0, 4.0), 1.0),
152+
(2, "bar", 4, Vectors.dense(0.0, 1.0, 4.0), 2.0),
153+
(3, "bar", 5, Vectors.dense(0.0, 1.0, 5.0), 3.0),
154+
(4, "aaz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)
155+
).toDF("id", "a", "b", "features", "label"),
156+
Seq(
157+
(1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0),
158+
(2, "bar", 4, Vectors.dense(0.0, 1.0, 4.0), 2.0),
159+
(3, "bar", 5, Vectors.dense(0.0, 1.0, 5.0), 3.0),
160+
(4, "aaz", 5, Vectors.dense(1.0, 0.0, 5.0), 4.0)
161+
).toDF("id", "a", "b", "features", "label")
162+
)
163+
164+
var idx = 0
165+
for (orderType <- StringIndexer.supportedStringOrderType) {
166+
val model = formula.setStringIndexerOrderType(orderType).fit(original)
167+
val result = model.transform(original)
168+
val resultSchema = model.transformSchema(original.schema)
169+
assert(result.schema.toString == resultSchema.toString)
170+
assert(result.collect() === expected(idx).collect())
171+
idx += 1
172+
}
173+
}
174+
175+
test("test consistency with R when encoding string terms") {
176+
/*
177+
R code:
178+
179+
df <- data.frame(id = c(1, 2, 3, 4),
180+
a = c("foo", "bar", "bar", "aaz"),
181+
b = c(4, 4, 5, 5))
182+
model.matrix(id ~ a + b, df)[, -1]
183+
184+
abar afoo b
185+
0 1 4
186+
1 0 4
187+
1 0 5
188+
0 0 5
189+
*/
190+
val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "aaz", 5))
191+
.toDF("id", "a", "b")
192+
val formula = new RFormula().setFormula("id ~ a + b")
193+
.setStringIndexerOrderType(StringIndexer.alphabetDesc)
194+
195+
/*
196+
Note that the category dropped after encoding is the same between R and Spark
197+
(i.e., "aaz" is treated as the reference level).
198+
However, the column order is still different:
199+
R renders the columns in ascending alphabetical order ("bar", "foo"), while
200+
RFormula renders the columns in descending alphabetical order ("foo", "bar").
201+
*/
202+
val expected = Seq(
203+
(1, "foo", 4, Vectors.dense(1.0, 0.0, 4.0), 1.0),
204+
(2, "bar", 4, Vectors.dense(0.0, 1.0, 4.0), 2.0),
205+
(3, "bar", 5, Vectors.dense(0.0, 1.0, 5.0), 3.0),
206+
(4, "aaz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)
207+
).toDF("id", "a", "b", "features", "label")
208+
209+
val model = formula.fit(original)
210+
val result = model.transform(original)
211+
val resultSchema = model.transformSchema(original.schema)
212+
assert(result.schema.toString == resultSchema.toString)
213+
assert(result.collect() === expected.collect())
214+
}
215+
132216
test("index string label") {
133217
val formula = new RFormula().setFormula("id ~ a + b")
134218
val original =

0 commit comments

Comments
 (0)