Skip to content

Commit b54a586

Browse files
zero323 authored and srowen committed
[SPARK-17027][ML] Avoid integer overflow in PolynomialExpansion.getPolySize
Replaces the custom choose function with o.a.commons.math3.CombinatoricsUtils.binomialCoefficient. Tested with Spark unit tests. Author: zero323 <zero323@users.noreply.github.com> Closes #14614 from zero323/SPARK-17027. (cherry picked from commit 0ebf7c1) Signed-off-by: Sean Owen <sowen@cloudera.com>
1 parent 8a2b8fc commit b54a586

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ package org.apache.spark.ml.feature
1919

2020
import scala.collection.mutable
2121

22+
import org.apache.commons.math3.util.CombinatoricsUtils
23+
2224
import org.apache.spark.annotation.{Since, Experimental}
2325
import org.apache.spark.ml.UnaryTransformer
2426
import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators}
@@ -80,12 +82,12 @@ class PolynomialExpansion(override val uid: String)
8082
@Since("1.6.0")
8183
object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] {
8284

83-
private def choose(n: Int, k: Int): Int = {
84-
Range(n, n - k, -1).product / Range(k, 1, -1).product
85+
private def getPolySize(numFeatures: Int, degree: Int): Int = {
86+
val n = CombinatoricsUtils.binomialCoefficient(numFeatures + degree, degree)
87+
require(n <= Integer.MAX_VALUE)
88+
n.toInt
8589
}
8690

87-
private def getPolySize(numFeatures: Int, degree: Int): Int = choose(numFeatures + degree, degree)
88-
8991
private def expandDense(
9092
values: Array[Double],
9193
lastIdx: Int,

mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,5 +108,29 @@ class PolynomialExpansionSuite
108108
.setDegree(3)
109109
testDefaultReadWrite(t)
110110
}
111+
112+
test("SPARK-17027. Integer overflow in PolynomialExpansion.getPolySize") {
113+
val data: Array[(Vector, Int, Int)] = Array(
114+
(Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0), 3002, 4367),
115+
(Vectors.sparse(5, Seq((0, 1.0), (4, 5.0))), 3002, 4367),
116+
(Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0), 8007, 12375)
117+
)
118+
119+
val df = spark.createDataFrame(data)
120+
.toDF("features", "expectedPoly10size", "expectedPoly11size")
121+
122+
val t = new PolynomialExpansion()
123+
.setInputCol("features")
124+
.setOutputCol("polyFeatures")
125+
126+
for (i <- Seq(10, 11)) {
127+
val transformed = t.setDegree(i)
128+
.transform(df)
129+
.select(s"expectedPoly${i}size", "polyFeatures")
130+
.rdd.map { case Row(expected: Int, v: Vector) => expected == v.size }
131+
132+
assert(transformed.collect.forall(identity))
133+
}
134+
}
111135
}
112136

0 commit comments

Comments (0)