Skip to content

Commit 841f1d7

Browse files
WeichenXu123jkbradley
authored andcommitted
[SPARK-22332][ML][TEST] Fix NaiveBayes unit test occasionly fail (cause by test dataset not deterministic)
## What changes were proposed in this pull request? Fix NaiveBayes unit test occasionly fail: Set seed for `BrzMultinomial.sample`, make `generateNaiveBayesInput` output deterministic dataset. (If we do not set seed, the generated dataset will be random, and the model will be possible to exceed the tolerance in the test, which trigger this failure) ## How was this patch tested? Manually run tests multiple times and check each time output models contains the same values. Author: WeichenXu <weichen.xu@databricks.com> Closes apache#19558 from WeichenXu123/fix_nb_test_seed.
1 parent b377ef1 commit 841f1d7

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
2020
import scala.util.Random
2121

2222
import breeze.linalg.{DenseVector => BDV, Vector => BV}
23-
import breeze.stats.distributions.{Multinomial => BrzMultinomial}
23+
import breeze.stats.distributions.{Multinomial => BrzMultinomial, RandBasis => BrzRandBasis}
2424

2525
import org.apache.spark.{SparkException, SparkFunSuite}
2626
import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial}
@@ -335,6 +335,7 @@ object NaiveBayesSuite {
335335
val _pi = pi.map(math.exp)
336336
val _theta = theta.map(row => row.map(math.exp))
337337

338+
implicit val rngForBrzMultinomial = BrzRandBasis.withSeed(seed)
338339
for (i <- 0 until nPoints) yield {
339340
val y = calcLabel(rnd.nextDouble(), _pi)
340341
val xi = modelType match {

0 commit comments

Comments
 (0)