Skip to content

Commit 8b75731

Browse files
committed
Use merge_sorted() to optimize PiecewiseLinearFn
1 parent f9f6eb1 commit 8b75731

File tree

1 file changed

+42
-17
lines changed

1 file changed

+42
-17
lines changed

src/misc.rs

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,20 @@ pub fn slice_upper_bound<T: PartialOrd>(slice: &[T], key: &T) -> usize {
2121
.unwrap_err()
2222
}
2323

24+
/// Merge two sorted collections into one
25+
pub fn merge_sorted<T: PartialOrd>(
26+
i1: impl IntoIterator<Item = T>,
27+
i2: impl IntoIterator<Item = T>,
28+
) -> Vec<T> {
29+
let (mut i1, mut i2) = (i1.into_iter().peekable(), i2.into_iter().peekable());
30+
let mut merged = Vec::with_capacity(i1.size_hint().0 + i2.size_hint().0);
31+
while let (Some(a), Some(b)) = (i1.peek(), i2.peek()) {
32+
merged.push(if a <= b { i1.next() } else { i2.next() }.unwrap());
33+
}
34+
merged.extend(i1.chain(i2));
35+
merged
36+
}
37+
2438
/// A simple data structure for coordinate compression
2539
pub struct SparseIndex {
2640
coords: Vec<i64>,
@@ -41,7 +55,7 @@ impl SparseIndex {
4155
}
4256
}
4357

44-
/// Represents a minimum (lower envelope) of a collection of linear functions of a variable,
58+
/// Represents a maximum (upper envelope) of a collection of linear functions of a variable,
4559
/// evaluated using the convex hull trick with square root decomposition.
4660
pub struct PiecewiseLinearFn {
4761
sorted_lines: Vec<(f64, f64)>,
@@ -63,19 +77,19 @@ impl PiecewiseLinearFn {
6377
}
6478
}
6579

66-
/// Replaces the represented function with the minimum of itself and a provided line
67-
pub fn min_with(&mut self, slope: f64, intercept: f64) {
80+
/// Replaces the represented function with the maximum of itself and a provided line
81+
pub fn max_with(&mut self, slope: f64, intercept: f64) {
6882
self.recent_lines.push((slope, intercept));
6983
}
7084

7185
fn update_envelope(&mut self) {
72-
self.recent_lines.extend(self.sorted_lines.drain(..));
73-
self.recent_lines.sort_unstable_by(asserting_cmp); // TODO: switch to O(n) merge
86+
self.recent_lines.sort_unstable_by(asserting_cmp);
7487
self.intersections.clear();
7588

76-
for (new_m, new_b) in self.recent_lines.drain(..).rev() {
89+
for (new_m, new_b) in merge_sorted(self.recent_lines.drain(..), self.sorted_lines.drain(..))
90+
{
7791
while let Some(&(last_m, last_b)) = self.sorted_lines.last() {
78-
// If slopes are equal, get rid of the old line as its intercept is higher
92+
// If slopes are equal, get rid of the old line as its intercept is lower
7993
if (new_m - last_m).abs() > 1e-9 {
8094
let intr = (new_b - last_b) / (last_m - new_m);
8195
if self.intersections.last() < Some(&intr) {
@@ -96,8 +110,8 @@ impl PiecewiseLinearFn {
96110
.iter()
97111
.chain(self.sorted_lines.get(idx))
98112
.map(|&(m, b)| m * x + b)
99-
.min_by(asserting_cmp)
100-
.unwrap_or(1e18)
113+
.max_by(asserting_cmp)
114+
.unwrap_or(-1e18)
101115
}
102116

103117
/// Evaluates the function at x
@@ -129,6 +143,17 @@ mod test {
129143
}
130144
}
131145

146+
#[test]
147+
fn test_merge_sorted() {
148+
let vals1 = vec![16, 45, 45, 82];
149+
let vals2 = vec![-20, 40, 45, 50];
150+
let vals_merged = vec![-20, 16, 40, 45, 45, 45, 50, 82];
151+
152+
assert_eq!(merge_sorted(None, Some(42)), vec![42]);
153+
assert_eq!(merge_sorted(vals1.iter().cloned(), None), vals1);
154+
assert_eq!(merge_sorted(vals1, vals2), vals_merged);
155+
}
156+
132157
#[test]
133158
fn test_coord_compress() {
134159
let mut coords = vec![16, 99, 45, 18];
@@ -153,22 +178,22 @@ mod test {
153178

154179
#[test]
155180
fn test_convex_hull_trick() {
156-
let lines = [(0, 3), (1, 0), (-1, 8), (2, -1), (-1, 4)];
181+
let lines = [(0, -3), (-1, 0), (1, -8), (-2, 1), (1, -4)];
157182
let xs = [0, 1, 2, 3, 4, 5];
158183
// results[i] consists of the expected y-coordinates after processing
159184
// the first i+1 lines.
160185
let results = [
161-
[3, 3, 3, 3, 3, 3],
162-
[0, 1, 2, 3, 3, 3],
163-
[0, 1, 2, 3, 3, 3],
164-
[-1, 1, 2, 3, 3, 3],
165-
[-1, 1, 2, 1, 0, -1],
186+
[-3, -3, -3, -3, -3, -3],
187+
[0, -1, -2, -3, -3, -3],
188+
[0, -1, -2, -3, -3, -3],
189+
[1, -1, -2, -3, -3, -3],
190+
[1, -1, -2, -1, 0, 1],
166191
];
167192
for threshold in 0..=lines.len() {
168193
let mut func = PiecewiseLinearFn::with_merge_threshold(threshold);
169-
assert_eq!(func.evaluate(0.0), 1e18);
194+
assert_eq!(func.evaluate(0.0), -1e18);
170195
for (&(slope, intercept), expected) in lines.iter().zip(results.iter()) {
171-
func.min_with(slope as f64, intercept as f64);
196+
func.max_with(slope as f64, intercept as f64);
172197
let ys: Vec<i64> = xs.iter().map(|&x| func.evaluate(x as f64) as i64).collect();
173198
assert_eq!(expected, &ys[..]);
174199
}

0 commit comments

Comments
 (0)