
Commit 0ebf7c1

zero323 authored and srowen committed
[SPARK-17027][ML] Avoid integer overflow in PolynomialExpansion.getPolySize
## What changes were proposed in this pull request?

Replaces the custom choose function with o.a.commons.math3.CombinatoricsUtils.binomialCoefficient.

## How was this patch tested?

Spark unit tests.

Author: zero323 <zero323@users.noreply.github.com>

Closes #14614 from zero323/SPARK-17027.
1 parent cdaa562 commit 0ebf7c1
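
The overflow this fixes comes from the old choose helper: it evaluates a falling-factorial product in Int arithmetic, so the intermediate product wraps around even when the final binomial coefficient comfortably fits in an Int. Below is a minimal sketch (not part of the commit; the PolySizeOverflowDemo object and the chooseOld name are illustrative) that reproduces the problem for 5 features at degree 10 and shows the Long-based replacement, assuming commons-math3 is on the classpath:

import org.apache.commons.math3.util.CombinatoricsUtils

object PolySizeOverflowDemo {
  // Pre-patch logic: Range(n, n - k, -1).product runs in Int arithmetic,
  // so 15 * 14 * ... * 6 (about 1.09e10) wraps around before the division
  // can bring the value back into range.
  def chooseOld(n: Int, k: Int): Int =
    Range(n, n - k, -1).product / Range(k, 1, -1).product

  def main(args: Array[String]): Unit = {
    // C(15, 10) = 3003 fits easily in an Int, but the old helper returns
    // a negative value because of the intermediate overflow.
    println(chooseOld(15, 10))

    // binomialCoefficient computes in Long (and throws if even a Long
    // cannot hold the result); the patched getPolySize then checks the
    // value before narrowing it back to Int.
    val n = CombinatoricsUtils.binomialCoefficient(15, 10)
    require(n <= Integer.MAX_VALUE)
    println(n.toInt) // 3003
  }
}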

2 files changed: +30, -4 lines


mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala

Lines changed: 6 additions & 4 deletions
@@ -19,6 +19,8 @@ package org.apache.spark.ml.feature
 
 import scala.collection.mutable
 
+import org.apache.commons.math3.util.CombinatoricsUtils
+
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.linalg._
@@ -84,12 +86,12 @@ class PolynomialExpansion @Since("1.4.0") (@Since("1.4.0") override val uid: Str
 @Since("1.6.0")
 object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] {
 
-  private def choose(n: Int, k: Int): Int = {
-    Range(n, n - k, -1).product / Range(k, 1, -1).product
+  private def getPolySize(numFeatures: Int, degree: Int): Int = {
+    val n = CombinatoricsUtils.binomialCoefficient(numFeatures + degree, degree)
+    require(n <= Integer.MAX_VALUE)
+    n.toInt
   }
 
-  private def getPolySize(numFeatures: Int, degree: Int): Int = choose(numFeatures + degree, degree)
-
   private def expandDense(
       values: Array[Double],
       lastIdx: Int,

mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala

Lines changed: 24 additions & 0 deletions
@@ -116,5 +116,29 @@ class PolynomialExpansionSuite
       .setDegree(3)
     testDefaultReadWrite(t)
   }
+
+  test("SPARK-17027. Integer overflow in PolynomialExpansion.getPolySize") {
+    val data: Array[(Vector, Int, Int)] = Array(
+      (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0), 3002, 4367),
+      (Vectors.sparse(5, Seq((0, 1.0), (4, 5.0))), 3002, 4367),
+      (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0), 8007, 12375)
+    )
+
+    val df = spark.createDataFrame(data)
+      .toDF("features", "expectedPoly10size", "expectedPoly11size")
+
+    val t = new PolynomialExpansion()
+      .setInputCol("features")
+      .setOutputCol("polyFeatures")
+
+    for (i <- Seq(10, 11)) {
+      val transformed = t.setDegree(i)
+        .transform(df)
+        .select(s"expectedPoly${i}size", "polyFeatures")
+        .rdd.map { case Row(expected: Int, v: Vector) => expected == v.size }
+
+      assert(transformed.collect.forall(identity))
+    }
+  }
 }
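
For reference (not part of the diff): the number of monomials of degree at most d in n features, including the constant term, is C(n + d, d), and the expected sizes in the test above are one less than that because the expanded vector drops the constant. A quick sanity check of those numbers, assuming only commons-math3 on the classpath:

import org.apache.commons.math3.util.CombinatoricsUtils

// 5 features: C(15, 10) - 1 = 3002 and C(16, 11) - 1 = 4367
// 6 features: C(16, 10) - 1 = 8007 and C(17, 11) - 1 = 12375
for ((numFeatures, degree) <- Seq((5, 10), (5, 11), (6, 10), (6, 11))) {
  val size = CombinatoricsUtils.binomialCoefficient(numFeatures + degree, degree) - 1
  println(s"$numFeatures features, degree $degree -> $size expanded features")
}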
