Skip to content

Commit 2573e8d

Browse files
committed
Update the scala side of the pythonmllibapi and make the test a bit nicer too
1 parent 3a09170 commit 2573e8d

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ private[python] class PythonMLLibAPI extends Serializable {
399399
val sigma = si.map(_.asInstanceOf[DenseMatrix])
400400
val gaussians = Array.tabulate(weight.length){
401401
i => new MultivariateGaussian(mean(i), sigma(i))
402-
}
402+
}
403403
val model = new GaussianMixtureModel(weight, gaussians)
404404
model.predictSoft(data).map(Vectors.dense)
405405
}
@@ -494,7 +494,7 @@ private[python] class PythonMLLibAPI extends Serializable {
494494
def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = {
495495
new Normalizer(p).transform(rdd)
496496
}
497-
497+
498498
/**
499499
* Java stub for StandardScaler.fit(). This stub returns a
500500
* handle to the Java object instead of the content of the Java object.
@@ -685,12 +685,14 @@ private[python] class PythonMLLibAPI extends Serializable {
685685
lossStr: String,
686686
numIterations: Int,
687687
learningRate: Double,
688-
maxDepth: Int): GradientBoostedTreesModel = {
688+
maxDepth: Int,
689+
maxBins: Int): GradientBoostedTreesModel = {
689690
val boostingStrategy = BoostingStrategy.defaultParams(algoStr)
690691
boostingStrategy.setLoss(Losses.fromString(lossStr))
691692
boostingStrategy.setNumIterations(numIterations)
692693
boostingStrategy.setLearningRate(learningRate)
693694
boostingStrategy.treeStrategy.setMaxDepth(maxDepth)
695+
boostingStrategy.treeStrategy.setMaxBins(maxBins)
694696
boostingStrategy.treeStrategy.categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap
695697

696698
val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)

python/pyspark/mllib/tests.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -446,12 +446,10 @@ def test_regression(self):
446446
# Verify that maxBins is being passed through
447447
GradientBoostedTrees.trainRegressor(
448448
rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=32)
449-
try:
449+
with self.assertRaises(Exception) as cm:
450450
GradientBoostedTrees.trainRegressor(
451451
rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=1)
452-
self.fail("max bins was not passed through (or not verified!)")
453-
except Exception:
454-
self.pass()
452+
455453

456454

457455
class StatTests(MLlibTestCase):

0 commit comments

Comments
 (0)