RFormula
通过一个R model formula选择一个特定的列。
目前我们支持R
算子的一个受限的子集,包括~
,.
,:
,+
,-
。这些基本的算子是:
~
分开target
和terms
+
连接term
,+ 0
表示删除截距(intercept
)-
删除term
,- 1
表示删除截距:
交集.
除了target
之外的所有列
假设a
和b
是double
列,我们用下面简单的例子来证明RFormula
的有效性。
y ~ a + b
表示模型y ~ w0 + w1 * a + w2 * b
,其中w0
是截距,w1
和w2
是系数y ~ a + b + a:b - 1
表示模型y ~ w1 * a + w2 * b + w3 * a * b
,其中w1
,w2
,w3
是系数
RFormula
产生一个特征向量列和一个double
或string
类型的标签列。比如在线性回归中使用R
中的公式时,
字符串输入列是one-hot
编码,数值列强制转换为double
类型。如果标签列是字符串类型,它将使用StringIndexer
转换为double
类型。如果DataFrame
中不存在标签列,输出的标签列将通过公式中指定的返回变量来创建。
假设我们有一个DataFrame
,它的列名是id
, country
, hour
和clicked
。
id | country | hour | clicked
---|---------|------|---------
7 | "US" | 18 | 1.0
8 | "CA" | 12 | 0.0
9 | "NZ" | 15 | 0.0
如果我们用clicked ~ country + hour
(基于country
和hour
来预测clicked
)来作用于RFormula
,将会得到下面的结果。
id | country | hour | clicked | features | label
---|---------|------|---------|------------------|-------
7 | "US" | 18 | 1.0 | [0.0, 0.0, 18.0] | 1.0
8 | "CA" | 12 | 0.0 | [0.0, 1.0, 12.0] | 0.0
9 | "NZ" | 15 | 0.0 | [1.0, 0.0, 15.0] | 0.0
下面是代码调用的例子。
import org.apache.spark.ml.feature.RFormula
val dataset = spark.createDataFrame(Seq(
(7, "US", 18, 1.0),
(8, "CA", 12, 0.0),
(9, "NZ", 15, 0.0)
)).toDF("id", "country", "hour", "clicked")
val formula = new RFormula()
.setFormula("clicked ~ country + hour")
.setFeaturesCol("features")
.setLabelCol("label")
val output = formula.fit(dataset).transform(dataset)
output.select("features", "label").show()