Skip to content

Commit 4d64c7f

Browse files
committed
Revert "[SPARK-17027][ML] Avoid integer overflow in PolynomialExpansion.getPolySize"
This reverts commit b54a586.
1 parent b54a586 commit 4d64c7f

File tree

2 files changed

+4
-30
lines changed

2 files changed

+4
-30
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ package org.apache.spark.ml.feature
 
 import scala.collection.mutable
 
-import org.apache.commons.math3.util.CombinatoricsUtils
-
 import org.apache.spark.annotation.{Since, Experimental}
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators}
@@ -82,12 +80,12 @@ class PolynomialExpansion(override val uid: String)
 @Since("1.6.0")
 object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] {
 
-  private def getPolySize(numFeatures: Int, degree: Int): Int = {
-    val n = CombinatoricsUtils.binomialCoefficient(numFeatures + degree, degree)
-    require(n <= Integer.MAX_VALUE)
-    n.toInt
+  private def choose(n: Int, k: Int): Int = {
+    Range(n, n - k, -1).product / Range(k, 1, -1).product
   }
 
+  private def getPolySize(numFeatures: Int, degree: Int): Int = choose(numFeatures + degree, degree)
+
   private def expandDense(
       values: Array[Double],
       lastIdx: Int,

mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -108,29 +108,5 @@ class PolynomialExpansionSuite
       .setDegree(3)
     testDefaultReadWrite(t)
   }
-
-  test("SPARK-17027. Integer overflow in PolynomialExpansion.getPolySize") {
-    val data: Array[(Vector, Int, Int)] = Array(
-      (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0), 3002, 4367),
-      (Vectors.sparse(5, Seq((0, 1.0), (4, 5.0))), 3002, 4367),
-      (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0), 8007, 12375)
-    )
-
-    val df = spark.createDataFrame(data)
-      .toDF("features", "expectedPoly10size", "expectedPoly11size")
-
-    val t = new PolynomialExpansion()
-      .setInputCol("features")
-      .setOutputCol("polyFeatures")
-
-    for (i <- Seq(10, 11)) {
-      val transformed = t.setDegree(i)
-        .transform(df)
-        .select(s"expectedPoly${i}size", "polyFeatures")
-        .rdd.map { case Row(expected: Int, v: Vector) => expected == v.size }
-
-      assert(transformed.collect.forall(identity))
-    }
-  }
 }
 

0 commit comments

Comments
 (0)