@@ -236,9 +236,8 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
  * Only univariate (single feature) algorithm supported.
  *
  * Sequential PAV implementation based on:
- * Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani.
- * "Nearly-isotonic regression." Technometrics 53.1 (2011): 54-61.
- * Available from <a href="http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf">here</a>
+ * Grotzinger, S. J., and C. Witzgall.
+ * "Projections onto order simplexes." Applied Mathematics and Optimization 12.1 (1984): 247-270.
  *
  * Sequential PAV parallelization based on:
  * Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset.
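For context on what the sequential PAV computes, here is a quick sanity check through the public `IsotonicRegression` API (a toy sketch, not part of the patch: it assumes an active `SparkContext` named `sc`, and the numbers are worked out by hand):

```scala
import org.apache.spark.mllib.regression.IsotonicRegression

// Toy data as (label, feature, weight) triples. The labels 3.0 -> 2.0
// violate monotonicity as the feature increases.
val points = sc.parallelize(Seq(
  (1.0, 1.0, 1.0),
  (3.0, 2.0, 1.0),
  (2.0, 3.0, 1.0)))

val model = new IsotonicRegression().setIsotonic(true).run(points)

// PAV pools the violating pair into its weighted average
// (1.0 * 3.0 + 1.0 * 2.0) / (1.0 + 1.0) = 2.5, giving the monotone
// fit 1.0, 2.5, 2.5.
model.predict(2.0)  // 2.5
```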
@@ -312,90 +311,118 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Serializable {
   }

   /**
-   * Performs a pool adjacent violators algorithm (PAV).
-   * Uses approach with single processing of data where violators
-   * in previously processed data created by pooling are fixed immediately.
-   * Uses optimization of discovering monotonicity violating sequences (blocks).
+   * Performs a pool adjacent violators algorithm (PAV). Implements the algorithm originally
+   * described in [1], using the formulation from [2, 3]. Uses an array to keep track of the
+   * start and end indices of blocks.
    *
-   * @param input Input data of tuples (label, feature, weight).
+   * [1] Grotzinger, S. J., and C. Witzgall. "Projections onto order simplexes." Applied
+   *     Mathematics and Optimization 12.1 (1984): 247-270.
+   *
+   * [2] Best, Michael J., and Nilotpal Chakravarti. "Active set algorithms for isotonic
+   *     regression; a unifying framework." Mathematical Programming 47.1-3 (1990): 425-439.
+   *
+   * [3] Best, Michael J., Nilotpal Chakravarti, and Vasant A. Ubhaya. "Minimizing separable
+   *     convex functions subject to simple chain constraints." SIAM Journal on Optimization
+   *     10.3 (2000): 658-672.
+   *
+   * @param input Input data of tuples (label, feature, weight). Weights must be non-negative.
    * @return Result tuples (label, feature, weight) where labels were updated
    *         to form a monotone sequence as per isotonic regression definition.
    */
   private def poolAdjacentViolators(
       input: Array[(Double, Double, Double)]): Array[(Double, Double, Double)] = {

-    if (input.isEmpty) {
-      return Array.empty
+    val cleanInput = input.filter { case (y, x, weight) =>
+      require(
+        weight >= 0.0,
+        s"Negative weight at point ($y, $x, $weight). Weights must be non-negative")
+      weight > 0
     }

-    // Pools sub array within given bounds assigning weighted average value to all elements.
-    def pool(input: Array[(Double, Double, Double)], start: Int, end: Int): Unit = {
-      val poolSubArray = input.slice(start, end + 1)
+    if (cleanInput.isEmpty) {
+      return Array.empty
+    }

-      val weightedSum = poolSubArray.map(lp => lp._1 * lp._3).sum
-      val weight = poolSubArray.map(_._3).sum
+    // Keeps track of the start and end indices of the blocks. If [i, j] is a valid block from
+    // cleanInput(i) to cleanInput(j) (inclusive), then blockBounds(i) = j and blockBounds(j) = i.
+    // Initially, each data point is its own block.
+    val blockBounds = Array.range(0, cleanInput.length)

-      var i = start
-      while (i <= end) {
-        input(i) = (weightedSum / weight, input(i)._2, input(i)._3)
-        i = i + 1
-      }
+    // Keep track of the sum of weights and sum of weight * y for each block. weights(start)
+    // gives the values for the block. Entries that are not at the start of a block
+    // are meaningless.
+    val weights: Array[(Double, Double)] = cleanInput.map { case (y, _, weight) =>
+      (weight, weight * y)
     }

-    var i = 0
-    val len = input.length
-    while (i < len) {
-      var j = i
+    // A few convenience functions to make the code more readable.

-      // Find monotonicity violating sequence, if any.
-      while (j < len - 1 && input(j)._1 > input(j + 1)._1) {
-        j = j + 1
-      }
+    // blockStart and blockEnd have identical implementations. We create two different
+    // functions to make the code more expressive.
+    def blockEnd(start: Int): Int = blockBounds(start)
+    def blockStart(end: Int): Int = blockBounds(end)

-      // If monotonicity was not violated, move to next data point.
-      if (i == j) {
-        i = i + 1
-      } else {
-        // Otherwise pool the violating sequence
-        // and check if pooling caused monotonicity violation in previously processed points.
-        while (i >= 0 && input(i)._1 > input(i + 1)._1) {
-          pool(input, i, j)
-          i = i - 1
-        }
+    // The next block starts at the index after the end of this block.
+    def nextBlock(start: Int): Int = blockEnd(start) + 1

-        i = j
-      }
+    // The previous block ends at the index before the start of this block;
+    // we then use blockStart to find the start.
+    def prevBlock(start: Int): Int = blockStart(start - 1)
+
+    // Merge two adjacent blocks, updating blockBounds and weights to reflect the merge.
+    // Return the start index of the merged block.
+    def merge(block1: Int, block2: Int): Int = {
+      assert(
+        blockEnd(block1) + 1 == block2,
+        s"Attempting to merge non-consecutive blocks [${block1}, ${blockEnd(block1)}]" +
+          s" and [${block2}, ${blockEnd(block2)}]. This is likely a bug in the isotonic " +
+          "regression implementation. Please file a bug report.")
+      blockBounds(block1) = blockEnd(block2)
+      blockBounds(blockEnd(block2)) = block1
+      val w1 = weights(block1)
+      val w2 = weights(block2)
+      weights(block1) = (w1._1 + w2._1, w1._2 + w2._2)
+      block1
     }

-    // For points having the same prediction, we only keep two boundary points.
-    val compressed = ArrayBuffer.empty[(Double, Double, Double)]
+    // The average value of a block.
+    def average(start: Int): Double = weights(start)._2 / weights(start)._1

-    var (curLabel, curFeature, curWeight) = input.head
-    var rightBound = curFeature
-    def merge(): Unit = {
-      compressed += ((curLabel, curFeature, curWeight))
-      if (rightBound > curFeature) {
-        compressed += ((curLabel, rightBound, 0.0))
+    // Implement Algorithm PAV from [3].
+    // Merge on >= instead of > because it eliminates adjacent blocks with the same average,
+    // and we want to compress our output as much as possible. Both give correct results.
+    var i = 0
+    while (nextBlock(i) < cleanInput.length) {
+      if (average(i) >= average(nextBlock(i))) {
+        merge(i, nextBlock(i))
+        while ((i > 0) && (average(prevBlock(i)) >= average(i))) {
+          i = merge(prevBlock(i), i)
+        }
+      } else {
+        i = nextBlock(i)
       }
     }
-    i = 1
-    while (i < input.length) {
-      val (label, feature, weight) = input(i)
-      if (label == curLabel) {
-        curWeight += weight
-        rightBound = feature
+
+    // Construct the output by walking through the blocks in order.
+    val output = ArrayBuffer.empty[(Double, Double, Double)]
+    i = 0
+    while (i < cleanInput.length) {
+      // If the block has more than one point, emit a point at the start and one at the end
+      // of the block, each receiving half the block's weight. Otherwise, emit a single
+      // point with all the weight.
+      if (cleanInput(blockEnd(i))._2 > cleanInput(i)._2) {
+        output += ((average(i), cleanInput(i)._2, weights(i)._1 / 2))
+        output += ((average(i), cleanInput(blockEnd(i))._2, weights(i)._1 / 2))
       } else {
-        merge()
-        curLabel = label
-        curFeature = feature
-        curWeight = weight
-        rightBound = curFeature
+        output += ((average(i), cleanInput(i)._2, weights(i)._1))
       }
-      i += 1
+      i = nextBlock(i)
     }
-    merge()

-    compressed.toArray
+    output.toArray
   }
/**
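To make the block bookkeeping easy to experiment with outside Spark, here is a minimal, self-contained sketch of the same block-merging PAV on plain arrays. `PavSketch`, `pav`, and the `(label, weight)` interface are illustrative names, not part of the patch, and the sketch expands block averages back to per-point values instead of emitting the patch's compressed boundary-point output:

```scala
object PavSketch {
  // Pool adjacent violators on parallel arrays of labels y and weights w,
  // returning the isotonic (non-decreasing) fit for each input point.
  def pav(y: Array[Double], w: Array[Double]): Array[Double] = {
    val n = y.length
    if (n == 0) return Array.empty[Double]

    // blockBounds(i) = j and blockBounds(j) = i for a block [i, j];
    // initially every point is its own block.
    val blockBounds = Array.range(0, n)
    // Per-block (sum of weights, sum of weight * y), valid at block starts.
    val sums = Array.tabulate(n)(i => (w(i), w(i) * y(i)))

    def blockEnd(start: Int): Int = blockBounds(start)
    def blockStart(end: Int): Int = blockBounds(end)
    def nextBlock(start: Int): Int = blockEnd(start) + 1
    def prevBlock(start: Int): Int = blockStart(start - 1)
    def average(start: Int): Double = sums(start)._2 / sums(start)._1

    // Merge the block starting at b2 into the adjacent block starting at b1.
    def merge(b1: Int, b2: Int): Int = {
      blockBounds(b1) = blockEnd(b2)
      blockBounds(blockEnd(b2)) = b1
      val (w1, s1) = sums(b1)
      val (w2, s2) = sums(b2)
      sums(b1) = (w1 + w2, s1 + s2)
      b1
    }

    // Algorithm PAV: sweep forward, merging on >=, and after each merge
    // repair any violation this created with earlier blocks.
    var i = 0
    while (nextBlock(i) < n) {
      if (average(i) >= average(nextBlock(i))) {
        merge(i, nextBlock(i))
        while (i > 0 && average(prevBlock(i)) >= average(i)) {
          i = merge(prevBlock(i), i)
        }
      } else {
        i = nextBlock(i)
      }
    }

    // Expand each block's average back to its member points.
    val out = new Array[Double](n)
    i = 0
    while (i < n) {
      var j = i
      while (j <= blockEnd(i)) {
        out(j) = average(i)
        j += 1
      }
      i = nextBlock(i)
    }
    out
  }

  def main(args: Array[String]): Unit = {
    // The violating sequence 1, 3, 2 pools to 1, 2.5, 2.5.
    println(pav(Array(1.0, 3.0, 2.0), Array(1.0, 1.0, 1.0)).mkString(", "))
  }
}
```

Because a block is only ever addressed through its start index, the stale `blockBounds` entries in a block's interior never need to be repaired, which is what keeps each merge O(1).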