Skip to content

Commit 02457a7

Browse files
committed
address comments
1 parent 3c07243 commit 02457a7

File tree

4 files changed

+29
-30
lines changed

4 files changed

+29
-30
lines changed

mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala

Lines changed: 15 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -169,7 +169,7 @@ private[spark] object GradientBoostedTrees extends Logging {
169169
* @param loss evaluation metric.
170170
* @return Measure of model error on data
171171
*/
172-
def computeError(
172+
def computeWeightedError(
173173
data: RDD[Instance],
174174
trees: Array[DecisionTreeRegressionModel],
175175
treeWeights: Array[Double],
@@ -179,7 +179,7 @@ private[spark] object GradientBoostedTrees extends Logging {
179179
updatePrediction(features, acc, model, weight)
180180
}
181181
(loss.computeError(predicted, label) * weight, weight)
182-
}.treeReduce{ case ((err1, weight1), (err2, weight2)) =>
182+
}.treeReduce { case ((err1, weight1), (err2, weight2)) =>
183183
(err1 + err2, weight1 + weight2)
184184
}
185185
errSum / weightSum
@@ -191,13 +191,13 @@ private[spark] object GradientBoostedTrees extends Logging {
191191
* @param predError Prediction and error.
192192
* @return Measure of model error on data
193193
*/
194-
def computeError(
194+
def computeWeightedError(
195195
data: RDD[Instance],
196196
predError: RDD[(Double, Double)]): Double = {
197197
val (errSum, weightSum) = data.zip(predError).map {
198198
case (Instance(_, weight, _), (_, err)) =>
199199
(err * weight, weight)
200-
}.treeReduce{ case ((err1, weight1), (err2, weight2)) =>
200+
}.treeReduce { case ((err1, weight1), (err2, weight2)) =>
201201
(err1 + err2, weight1 + weight2)
202202
}
203203
errSum / weightSum
@@ -220,24 +220,18 @@ private[spark] object GradientBoostedTrees extends Logging {
220220
treeWeights: Array[Double],
221221
loss: OldLoss,
222222
algo: OldAlgo.Value): Array[Double] = {
223-
224-
val sc = data.sparkContext
225223
val remappedData = algo match {
226224
case OldAlgo.Classification =>
227225
data.map(x => Instance((x.label * 2) - 1, x.weight, x.features))
228226
case _ => data
229227
}
230228

231-
val broadcastTrees = sc.broadcast(trees)
232-
val localTreeWeights = treeWeights
233229
val numTrees = trees.length
234-
235230
val (errSum, weightSum) = remappedData.mapPartitions { iter =>
236-
val trees = broadcastTrees.value
237231
iter.map { case Instance(label, weight, features) =>
238232
val pred = Array.tabulate(numTrees) { i =>
239233
trees(i).rootNode.predictImpl(features)
240-
.prediction * localTreeWeights(i)
234+
.prediction * treeWeights(i)
241235
}
242236
val err = pred.scanLeft(0.0)(_ + _).drop(1)
243237
.map(p => loss.computeError(p, label) * weight)
@@ -248,7 +242,6 @@ private[spark] object GradientBoostedTrees extends Logging {
248242
(err1, weight1 + weight2)
249243
}
250244

251-
broadcastTrees.destroy()
252245
errSum.map(_ / weightSum)
253246
}
254247

@@ -298,8 +291,10 @@ private[spark] object GradientBoostedTrees extends Logging {
298291
}
299292

300293
// Prepare periodic checkpointers
294+
// Note: this is checkpointing the unweighted training error
301295
val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
302296
treeStrategy.getCheckpointInterval, input.sparkContext)
297+
// Note: this is checkpointing the unweighted validation error
303298
val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
304299
treeStrategy.getCheckpointInterval, input.sparkContext)
305300

@@ -319,15 +314,19 @@ private[spark] object GradientBoostedTrees extends Logging {
319314

320315
var predError = computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
321316
predErrorCheckpointer.update(predError)
322-
logDebug("error of gbt = " + computeError(input, predError))
317+
logDebug("error of gbt = " + computeWeightedError(input, predError))
323318

324319
// Note: A model of type regression is used since we require raw prediction
325320
timer.stop("building tree 0")
326321

327322
var validatePredError =
328323
computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
329324
if (validate) validatePredErrorCheckpointer.update(validatePredError)
330-
var bestValidateError = if (validate) computeError(validationInput, validatePredError) else 0.0
325+
var bestValidateError = if (validate) {
326+
computeWeightedError(validationInput, validatePredError)
327+
} else {
328+
0.0
329+
}
331330
var bestM = 1
332331

333332
var m = 1
@@ -356,7 +355,7 @@ private[spark] object GradientBoostedTrees extends Logging {
356355
predError = updatePredictionError(
357356
input, predError, baseLearnerWeights(m), baseLearners(m), loss)
358357
predErrorCheckpointer.update(predError)
359-
logDebug("error of gbt = " + computeError(input, predError))
358+
logDebug("error of gbt = " + computeWeightedError(input, predError))
360359

361360
if (validate) {
362361
// Stop training early if
@@ -367,7 +366,7 @@ private[spark] object GradientBoostedTrees extends Logging {
367366
validatePredError = updatePredictionError(
368367
validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
369368
validatePredErrorCheckpointer.update(validatePredError)
370-
val currentValidateError = computeError(validationInput, validatePredError)
369+
val currentValidateError = computeWeightedError(validationInput, validatePredError)
371370
if (bestValidateError - currentValidateError < validationTol * Math.max(
372371
currentValidateError, 0.01)) {
373372
doneLearning = true

mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -405,11 +405,11 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
405405
case LabeledPoint(label, features) =>
406406
Instance(label * 2 - 1, 1.0, features)
407407
}
408-
val lossErr1 = GradientBoostedTrees.computeError(remappedValidationData,
408+
val lossErr1 = GradientBoostedTrees.computeWeightedError(remappedValidationData,
409409
model1.trees, model1.treeWeights, model1.getOldLossType)
410-
val lossErr2 = GradientBoostedTrees.computeError(remappedValidationData,
410+
val lossErr2 = GradientBoostedTrees.computeWeightedError(remappedValidationData,
411411
model2.trees, model2.treeWeights, model2.getOldLossType)
412-
val lossErr3 = GradientBoostedTrees.computeError(remappedValidationData,
412+
val lossErr3 = GradientBoostedTrees.computeWeightedError(remappedValidationData,
413413
model3.trees, model3.treeWeights, model3.getOldLossType)
414414

415415
assert(evalArr(0) ~== lossErr1 relTol 1E-3)
@@ -443,9 +443,9 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
443443
case LabeledPoint(label, features) =>
444444
Instance(label * 2 - 1, 1.0, features)
445445
}
446-
(GradientBoostedTrees.computeError(remappedRdd, modelWithoutValidation.trees,
446+
(GradientBoostedTrees.computeWeightedError(remappedRdd, modelWithoutValidation.trees,
447447
modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType),
448-
GradientBoostedTrees.computeError(remappedRdd, modelWithValidation.trees,
448+
GradientBoostedTrees.computeWeightedError(remappedRdd, modelWithValidation.trees,
449449
modelWithValidation.treeWeights, modelWithValidation.getOldLossType))
450450
}
451451
assert(errorWithValidation < errorWithoutValidation)

mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,11 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
243243

244244
for (evalLossType <- GBTRegressor.supportedLossTypes) {
245245
val evalArr = model3.evaluateEachIteration(validationData.toDF, evalLossType)
246-
val lossErr1 = GradientBoostedTrees.computeError(validationData.map(_.toInstance),
246+
val lossErr1 = GradientBoostedTrees.computeWeightedError(validationData.map(_.toInstance),
247247
model1.trees, model1.treeWeights, model1.convertToOldLossType(evalLossType))
248-
val lossErr2 = GradientBoostedTrees.computeError(validationData.map(_.toInstance),
248+
val lossErr2 = GradientBoostedTrees.computeWeightedError(validationData.map(_.toInstance),
249249
model2.trees, model2.treeWeights, model2.convertToOldLossType(evalLossType))
250-
val lossErr3 = GradientBoostedTrees.computeError(validationData.map(_.toInstance),
250+
val lossErr3 = GradientBoostedTrees.computeWeightedError(validationData.map(_.toInstance),
251251
model3.trees, model3.treeWeights, model3.convertToOldLossType(evalLossType))
252252

253253
assert(evalArr(0) ~== lossErr1 relTol 1E-3)
@@ -278,11 +278,11 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
278278
// early stop
279279
assert(modelWithValidation.numTrees < numIter)
280280

281-
val errorWithoutValidation = GradientBoostedTrees.computeError(
281+
val errorWithoutValidation = GradientBoostedTrees.computeWeightedError(
282282
validationData.map(_.toInstance),
283283
modelWithoutValidation.trees, modelWithoutValidation.treeWeights,
284284
modelWithoutValidation.getOldLossType)
285-
val errorWithValidation = GradientBoostedTrees.computeError(
285+
val errorWithValidation = GradientBoostedTrees.computeWeightedError(
286286
validationData.map(_.toInstance),
287287
modelWithValidation.trees, modelWithValidation.treeWeights,
288288
modelWithValidation.getOldLossType)

mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,12 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
5656
val (errorWithoutValidation, errorWithValidation) = {
5757
if (algo == Classification) {
5858
val remappedRdd = validateRdd.map(x => Instance(2 * x.label - 1, x.weight, x.features))
59-
(GradientBoostedTrees.computeError(remappedRdd, trees, treeWeights, loss),
60-
GradientBoostedTrees.computeError(remappedRdd, validateTrees,
59+
(GradientBoostedTrees.computeWeightedError(remappedRdd, trees, treeWeights, loss),
60+
GradientBoostedTrees.computeWeightedError(remappedRdd, validateTrees,
6161
validateTreeWeights, loss))
6262
} else {
63-
(GradientBoostedTrees.computeError(validateRdd, trees, treeWeights, loss),
64-
GradientBoostedTrees.computeError(validateRdd, validateTrees,
63+
(GradientBoostedTrees.computeWeightedError(validateRdd, trees, treeWeights, loss),
64+
GradientBoostedTrees.computeWeightedError(validateRdd, validateTrees,
6565
validateTreeWeights, loss))
6666
}
6767
}

0 commit comments

Comments
 (0)