@@ -19,12 +19,13 @@ package org.apache.spark.ml.regression
1919
2020import org .apache .spark .SparkFunSuite
2121import org .apache .spark .ml .impl .TreeTests
22+ import org .apache .spark .mllib .linalg .Vectors
2223import org .apache .spark .mllib .regression .LabeledPoint
2324import org .apache .spark .mllib .tree .{EnsembleTestHelper , GradientBoostedTrees => OldGBT }
2425import org .apache .spark .mllib .tree .configuration .{Algo => OldAlgo }
2526import org .apache .spark .mllib .util .MLlibTestSparkContext
2627import org .apache .spark .rdd .RDD
27- import org .apache .spark .sql .DataFrame
28+ import org .apache .spark .sql .{ DataFrame , Row }
2829
2930
3031/**
@@ -67,6 +68,26 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
6768 }
6869 }
6970
// Sanity check that GBTRegressor produces regression-scale outputs on a tiny dataset.
//
// The six training points have labels near +10 or near -5. The model is fit with
// maxDepth = 2 and maxIter = 2 and then applied back to the training data; the test
// only requires that the predicted range reaches above 2 and below -1.
test("GBTRegressor behaves reasonably on toy data") {
  val df = sqlContext.createDataFrame(Seq(
    LabeledPoint(10, Vectors.dense(1, 2, 3, 4)),
    LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)),
    LabeledPoint(11, Vectors.dense(2, 2, 3, 4)),
    LabeledPoint(-6, Vectors.dense(6, 4, 2, 1)),
    LabeledPoint(9, Vectors.dense(1, 2, 6, 4)),
    LabeledPoint(-4, Vectors.dense(6, 3, 2, 2))
  ))
  val gbt = new GBTRegressor()
    .setMaxDepth(2)
    .setMaxIter(2)
  val model = gbt.fit(df)
  val preds = model.transform(df)
  // The column name must match exactly: no leading/trailing whitespace.
  val predictions = preds.select("prediction").map(_.getDouble(0))
  // Evaluate each extreme once so the failure message can report the observed value.
  val maxPrediction = predictions.max()
  val minPrediction = predictions.min()
  // Checks based on SPARK-8736 (to ensure it is not doing classification)
  assert(maxPrediction > 2, s"expected max prediction > 2 but got $maxPrediction")
  assert(minPrediction < -1, s"expected min prediction < -1 but got $minPrediction")
}
90+
7091 // TODO: Reinstate test once runWithValidation is implemented SPARK-7132
7192 /*
7293 test("runWithValidation stops early and performs better on a validation dataset") {
0 commit comments