
Commit 48a0fff

separate Model from StandardScaler algorithm
1 parent 89f3486
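
This commit splits the old mutable StandardScaler, whose fit mutated internal state and returned this, into a stateless StandardScaler algorithm and an immutable StandardScalerModel (holding the learned mean and scaling factor) that implements VectorTransformer. A minimal sketch of the resulting call pattern; dataRDD is a hypothetical stand-in for any RDD[Vector]:

import org.apache.spark.mllib.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Hypothetical input; any RDD[Vector] works here.
val dataRDD: RDD[Vector] = ???

// Before this commit: fit mutated the scaler and returned `this`,
// so the same object both learned and applied the statistics:
//   val scaler = new StandardScaler(); scaler.fit(dataRDD)
//   val scaled = scaler.transform(dataRDD)

// After this commit: fit returns a separate, immutable model.
val scaler = new StandardScaler(withMean = false, withStd = true)
val model = scaler.fit(dataRDD)        // StandardScalerModel
val scaled = model.transform(dataRDD)  // standardized RDD[Vector]

Returning a model also lets one scaler instance fit several datasets without the fitted state of one leaking into another.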

2 files changed: 43 additions & 41 deletions

mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala

Lines changed: 21 additions & 13 deletions
@@ -35,38 +35,47 @@ import org.apache.spark.rdd.RDD
  * @param withStd True by default. Scales the data to unit standard deviation.
  */
 @Experimental
-class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer {
+class StandardScaler(withMean: Boolean, withStd: Boolean) {
 
   def this() = this(false, true)
 
   require(withMean || withStd, s"withMean and withStd both equal to false. Doing nothing.")
 
-  private var mean: BV[Double] = _
-  private var factor: BV[Double] = _
-
   /**
    * Computes the mean and variance and stores as a model to be used for later scaling.
    *
    * @param data The data used to compute the mean and variance to build the transformation model.
-   * @return This StandardScalar object.
+   * @return a StandardScalarModel
    */
-  def fit(data: RDD[Vector]): this.type = {
+  def fit(data: RDD[Vector]): StandardScalerModel = {
     val summary = data.treeAggregate(new MultivariateOnlineSummarizer)(
       (aggregator, data) => aggregator.add(data),
       (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
 
-    mean = summary.mean.toBreeze
-    factor = summary.variance.toBreeze
-    require(mean.length == factor.length)
+    val mean = summary.mean.toBreeze
+    val factor = summary.variance.toBreeze
+    require(mean.size == factor.size)
 
     var i = 0
-    while (i < factor.length) {
+    while (i < factor.size) {
       factor(i) = if (factor(i) != 0.0) 1.0 / math.sqrt(factor(i)) else 0.0
       i += 1
     }
 
-    this
+    new StandardScalerModel(withMean, withStd, mean, factor)
   }
+}
+
+/**
+ * :: Experimental ::
+ * Represents a StandardScaler model that can transform vectors.
+ */
+@Experimental
+class StandardScalerModel private[mllib] (
+    val withMean: Boolean,
+    val withStd: Boolean,
+    val mean: BV[Double],
+    val factor: BV[Double]) extends VectorTransformer {
 
   /**
    * Applies standardization transformation on a vector.
@@ -81,7 +90,7 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer {
       "Haven't learned column summary statistics yet. Call fit first.")
     }
 
-    require(vector.size == mean.length)
+    require(vector.size == mean.size)
 
     if (withMean) {
       vector.toBreeze match {
@@ -115,5 +124,4 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer {
       vector
     }
   }
-
 }
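
For reference, the arithmetic is unchanged by the refactor: fit keeps the column means and sets factor(i) = 1 / sqrt(variance(i)) (or 0.0 for zero-variance columns, per the loop above), and transform centers and/or scales each component. A simplified sketch of that transformation on plain arrays, not the actual Breeze-backed MLlib code:

// Simplified sketch of the standardization applied by transform,
// where factor(i) = 1.0 / math.sqrt(variance(i)) as computed in fit.
def standardize(
    v: Array[Double],
    mean: Array[Double],
    factor: Array[Double],
    withMean: Boolean,
    withStd: Boolean): Array[Double] = {
  require(v.length == mean.length && mean.length == factor.length)
  Array.tabulate(v.length) { i =>
    val centered = if (withMean) v(i) - mean(i) else v(i)
    if (withStd) centered * factor(i) else centered
  }
}

// e.g. standardize(Array(2.0, 4.0), Array(1.0, 2.0), Array(0.5, 0.25),
//   withMean = true, withStd = true) returns Array(0.5, 0.5)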

mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala

Lines changed: 22 additions & 28 deletions
@@ -50,23 +50,17 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext {
     val standardizer2 = new StandardScaler()
     val standardizer3 = new StandardScaler(withMean = true, withStd = false)
 
-    withClue("Using a standardizer before fitting the model should throw exception.") {
-      intercept[IllegalStateException] {
-        data.map(standardizer1.transform)
-      }
-    }
-
-    standardizer1.fit(dataRDD)
-    standardizer2.fit(dataRDD)
-    standardizer3.fit(dataRDD)
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
 
-    val data1 = data.map(standardizer1.transform)
-    val data2 = data.map(standardizer2.transform)
-    val data3 = data.map(standardizer3.transform)
+    val data1 = data.map(model1.transform)
+    val data2 = data.map(model2.transform)
+    val data3 = data.map(model3.transform)
 
-    val data1RDD = standardizer1.transform(dataRDD)
-    val data2RDD = standardizer2.transform(dataRDD)
-    val data3RDD = standardizer3.transform(dataRDD)
+    val data1RDD = model1.transform(dataRDD)
+    val data2RDD = model2.transform(dataRDD)
+    val data3RDD = model3.transform(dataRDD)
 
     val summary = computeSummary(dataRDD)
     val summary1 = computeSummary(data1RDD)
@@ -129,25 +123,25 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext {
     val standardizer2 = new StandardScaler()
     val standardizer3 = new StandardScaler(withMean = true, withStd = false)
 
-    standardizer1.fit(dataRDD)
-    standardizer2.fit(dataRDD)
-    standardizer3.fit(dataRDD)
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
 
-    val data2 = data.map(standardizer2.transform)
+    val data2 = data.map(model2.transform)
 
     withClue("Standardization with mean can not be applied on sparse input.") {
       intercept[IllegalArgumentException] {
-        data.map(standardizer1.transform)
+        data.map(model1.transform)
       }
     }
 
     withClue("Standardization with mean can not be applied on sparse input.") {
       intercept[IllegalArgumentException] {
-        data.map(standardizer3.transform)
+        data.map(model3.transform)
      }
     }
 
-    val data2RDD = standardizer2.transform(dataRDD)
+    val data2RDD = model2.transform(dataRDD)
 
     val summary2 = computeSummary(data2RDD)
 
@@ -181,13 +175,13 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext {
     val standardizer2 = new StandardScaler(withMean = true, withStd = false)
     val standardizer3 = new StandardScaler(withMean = false, withStd = true)
 
-    standardizer1.fit(dataRDD)
-    standardizer2.fit(dataRDD)
-    standardizer3.fit(dataRDD)
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
 
-    val data1 = data.map(standardizer1.transform)
-    val data2 = data.map(standardizer2.transform)
-    val data3 = data.map(standardizer3.transform)
+    val data1 = data.map(model1.transform)
+    val data2 = data.map(model2.transform)
+    val data3 = data.map(model3.transform)
 
     assert(data1.forall(_.toArray.forall(_ == 0.0)),
       "The variance is zero, so the transformed result should be 0.0")
