Skip to content

Commit 3c07243

Browse files
committed
update Instance scope
1 parent f0d890a commit 3c07243

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.spark.ml.linalg.Vector
2626
* @param weight The weight of this instance.
2727
* @param features The vector of features for this data point.
2828
*/
29-
private[ml] case class Instance(label: Double, weight: Double, features: Vector)
29+
private[spark] case class Instance(label: Double, weight: Double, features: Vector)
3030

3131
/**
3232
* Case class that represents an instance of data point with

mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import com.github.fommil.netlib.BLAS
2121

2222
import org.apache.spark.{SparkException, SparkFunSuite}
2323
import org.apache.spark.ml.classification.LinearSVCSuite.generateSVMInput
24-
import org.apache.spark.ml.feature.LabeledPoint
24+
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
2525
import org.apache.spark.ml.linalg.{Vector, Vectors}
2626
import org.apache.spark.ml.param.ParamsSuite
2727
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
@@ -401,8 +401,10 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
401401
model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures, model3.numClasses)
402402

403403
val evalArr = model3.evaluateEachIteration(validationData.toDF)
404-
val remappedValidationData = validationData.map(
405-
x => LabeledPoint((x.label * 2) - 1, x.features).toInstance)
404+
val remappedValidationData = validationData.map {
405+
case LabeledPoint(label, features) =>
406+
Instance(label * 2 - 1, 1.0, features)
407+
}
406408
val lossErr1 = GradientBoostedTrees.computeError(remappedValidationData,
407409
model1.trees, model1.treeWeights, model1.getOldLossType)
408410
val lossErr2 = GradientBoostedTrees.computeError(remappedValidationData,
@@ -437,8 +439,10 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
437439
assert(modelWithValidation.numTrees < numIter)
438440

439441
val (errorWithoutValidation, errorWithValidation) = {
440-
val remappedRdd = validationData.map(x =>
441-
LabeledPoint(2 * x.label - 1, x.features).toInstance)
442+
val remappedRdd = validationData.map {
443+
case LabeledPoint(label, features) =>
444+
Instance(label * 2 - 1, 1.0, features)
445+
}
442446
(GradientBoostedTrees.computeError(remappedRdd, modelWithoutValidation.trees,
443447
modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType),
444448
GradientBoostedTrees.computeError(remappedRdd, modelWithValidation.trees,

0 commit comments

Comments
 (0)