@@ -169,7 +169,7 @@ private[spark] object GradientBoostedTrees extends Logging {
169
169
* @param loss evaluation metric.
170
170
* @return Measure of model error on data
171
171
*/
172
- def computeError (
172
+ def computeWeightedError (
173
173
data : RDD [Instance ],
174
174
trees : Array [DecisionTreeRegressionModel ],
175
175
treeWeights : Array [Double ],
@@ -179,7 +179,7 @@ private[spark] object GradientBoostedTrees extends Logging {
179
179
updatePrediction(features, acc, model, weight)
180
180
}
181
181
(loss.computeError(predicted, label) * weight, weight)
182
- }.treeReduce{ case ((err1, weight1), (err2, weight2)) =>
182
+ }.treeReduce { case ((err1, weight1), (err2, weight2)) =>
183
183
(err1 + err2, weight1 + weight2)
184
184
}
185
185
errSum / weightSum
@@ -191,13 +191,13 @@ private[spark] object GradientBoostedTrees extends Logging {
191
191
* @param predError Prediction and error.
192
192
* @return Measure of model error on data
193
193
*/
194
- def computeError (
194
+ def computeWeightedError (
195
195
data : RDD [Instance ],
196
196
predError : RDD [(Double , Double )]): Double = {
197
197
val (errSum, weightSum) = data.zip(predError).map {
198
198
case (Instance (_, weight, _), (_, err)) =>
199
199
(err * weight, weight)
200
- }.treeReduce{ case ((err1, weight1), (err2, weight2)) =>
200
+ }.treeReduce { case ((err1, weight1), (err2, weight2)) =>
201
201
(err1 + err2, weight1 + weight2)
202
202
}
203
203
errSum / weightSum
@@ -220,24 +220,18 @@ private[spark] object GradientBoostedTrees extends Logging {
220
220
treeWeights : Array [Double ],
221
221
loss : OldLoss ,
222
222
algo : OldAlgo .Value ): Array [Double ] = {
223
-
224
- val sc = data.sparkContext
225
223
val remappedData = algo match {
226
224
case OldAlgo .Classification =>
227
225
data.map(x => Instance ((x.label * 2 ) - 1 , x.weight, x.features))
228
226
case _ => data
229
227
}
230
228
231
- val broadcastTrees = sc.broadcast(trees)
232
- val localTreeWeights = treeWeights
233
229
val numTrees = trees.length
234
-
235
230
val (errSum, weightSum) = remappedData.mapPartitions { iter =>
236
- val trees = broadcastTrees.value
237
231
iter.map { case Instance (label, weight, features) =>
238
232
val pred = Array .tabulate(numTrees) { i =>
239
233
trees(i).rootNode.predictImpl(features)
240
- .prediction * localTreeWeights (i)
234
+ .prediction * treeWeights (i)
241
235
}
242
236
val err = pred.scanLeft(0.0 )(_ + _).drop(1 )
243
237
.map(p => loss.computeError(p, label) * weight)
@@ -248,7 +242,6 @@ private[spark] object GradientBoostedTrees extends Logging {
248
242
(err1, weight1 + weight2)
249
243
}
250
244
251
- broadcastTrees.destroy()
252
245
errSum.map(_ / weightSum)
253
246
}
254
247
@@ -298,8 +291,10 @@ private[spark] object GradientBoostedTrees extends Logging {
298
291
}
299
292
300
293
// Prepare periodic checkpointers
294
+ // Note: this is checkpointing the unweighted training error
301
295
val predErrorCheckpointer = new PeriodicRDDCheckpointer [(Double , Double )](
302
296
treeStrategy.getCheckpointInterval, input.sparkContext)
297
+ // Note: this is checkpointing the unweighted validation error
303
298
val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer [(Double , Double )](
304
299
treeStrategy.getCheckpointInterval, input.sparkContext)
305
300
@@ -319,15 +314,19 @@ private[spark] object GradientBoostedTrees extends Logging {
319
314
320
315
var predError = computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
321
316
predErrorCheckpointer.update(predError)
322
- logDebug(" error of gbt = " + computeError (input, predError))
317
+ logDebug(" error of gbt = " + computeWeightedError (input, predError))
323
318
324
319
// Note: A model of type regression is used since we require raw prediction
325
320
timer.stop(" building tree 0" )
326
321
327
322
var validatePredError =
328
323
computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
329
324
if (validate) validatePredErrorCheckpointer.update(validatePredError)
330
- var bestValidateError = if (validate) computeError(validationInput, validatePredError) else 0.0
325
+ var bestValidateError = if (validate) {
326
+ computeWeightedError(validationInput, validatePredError)
327
+ } else {
328
+ 0.0
329
+ }
331
330
var bestM = 1
332
331
333
332
var m = 1
@@ -356,7 +355,7 @@ private[spark] object GradientBoostedTrees extends Logging {
356
355
predError = updatePredictionError(
357
356
input, predError, baseLearnerWeights(m), baseLearners(m), loss)
358
357
predErrorCheckpointer.update(predError)
359
- logDebug(" error of gbt = " + computeError (input, predError))
358
+ logDebug(" error of gbt = " + computeWeightedError (input, predError))
360
359
361
360
if (validate) {
362
361
// Stop training early if
@@ -367,7 +366,7 @@ private[spark] object GradientBoostedTrees extends Logging {
367
366
validatePredError = updatePredictionError(
368
367
validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
369
368
validatePredErrorCheckpointer.update(validatePredError)
370
- val currentValidateError = computeError (validationInput, validatePredError)
369
+ val currentValidateError = computeWeightedError (validationInput, validatePredError)
371
370
if (bestValidateError - currentValidateError < validationTol * Math .max(
372
371
currentValidateError, 0.01 )) {
373
372
doneLearning = true
0 commit comments