Skip to content

Commit 93f3994

Browse files
authored
Convex hull trick (#14)
Finished implementation of convex hull trick with sqrt decomposition
1 parent 2cab263 commit 93f3994

File tree

1 file changed

+68
-12
lines changed

1 file changed

+68
-12
lines changed

src/range_query/sqrt_decomp.rs

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,11 @@ impl MoState for DistinctVals {
103103

104104
/// Represents a minimum (lower envelope) of a collection of linear functions of a variable,
105105
/// evaluated using the convex hull trick with square root decomposition.
106+
#[derive(Debug)]
106107
pub struct PiecewiseLinearFn {
107-
sorted_lines: Vec<(i64, i64)>,
108-
recent_lines: Vec<(i64, i64)>,
108+
sorted_lines: Vec<(f64, f64)>,
109+
intersections: Vec<f64>,
110+
recent_lines: Vec<(f64, f64)>,
109111
merge_threshold: usize,
110112
}
111113

@@ -116,31 +118,67 @@ impl PiecewiseLinearFn {
116118
pub fn with_merge_threshold(merge_threshold: usize) -> Self {
117119
Self {
118120
sorted_lines: vec![],
121+
intersections: vec![],
119122
recent_lines: vec![],
120123
merge_threshold,
121124
}
122125
}
123126

124127
/// Replaces this function with the minimum of itself and a provided line
125-
pub fn min_with(&mut self, slope: i64, intercept: i64) {
128+
pub fn min_with(&mut self, slope: f64, intercept: f64) {
126129
self.recent_lines.push((slope, intercept));
127130
}
128131

129132
fn update_envelope(&mut self) {
130133
self.recent_lines.extend(self.sorted_lines.drain(..));
131-
self.recent_lines.sort_unstable();
132-
for (slope, intercept) in self.recent_lines.drain(..) {
133-
// TODO: do convex hull trick algorithm
134-
self.sorted_lines.push((slope, intercept));
134+
self.recent_lines
135+
.sort_unstable_by(|x, y| y.partial_cmp(&x).unwrap());
136+
self.intersections.clear();
137+
138+
for (m1, b1) in self.recent_lines.drain(..) {
139+
while let Some(&(m2, b2)) = self.sorted_lines.last() {
140+
// If slopes are equal, the later line will always have lower
141+
// intercept, so we can get rid of the old one.
142+
if (m1 - m2).abs() > 1e-10f64 {
143+
let new_intersection = (b1 - b2) / (m2 - m1);
144+
if &new_intersection > self.intersections.last().unwrap_or(&f64::MIN) {
145+
self.intersections.push(new_intersection);
146+
break;
147+
}
148+
}
149+
self.intersections.pop();
150+
self.sorted_lines.pop();
151+
}
152+
self.sorted_lines.push((m1, b1));
135153
}
136154
}
137155

138-
fn eval_helper(&self, x: i64) -> i64 {
139-
0 // TODO: pick actual minimum, or infinity if empty
156+
fn eval_in_envelope(&self, x: f64) -> f64 {
157+
if self.sorted_lines.is_empty() {
158+
return f64::MAX;
159+
}
160+
let idx = match self
161+
.intersections
162+
.binary_search_by(|y| y.partial_cmp(&x).unwrap())
163+
{
164+
Ok(k) => k,
165+
Err(k) => k,
166+
};
167+
let (m, b) = self.sorted_lines[idx];
168+
m * x + b
169+
}
170+
171+
fn eval_helper(&self, x: f64) -> f64 {
172+
self.recent_lines
173+
.iter()
174+
.map(|&(m, b)| m * x + b)
175+
.min_by(|x, y| x.partial_cmp(y).unwrap())
176+
.unwrap_or(f64::MAX)
177+
.min(self.eval_in_envelope(x))
140178
}
141179

142180
/// Evaluates the function at x
143-
pub fn evaluate(&mut self, x: i64) -> i64 {
181+
pub fn evaluate(&mut self, x: f64) -> f64 {
144182
if self.recent_lines.len() > self.merge_threshold {
145183
self.update_envelope();
146184
}
@@ -164,7 +202,25 @@ mod test {
164202

165203
#[test]
166204
fn test_convex_hull_trick() {
167-
let mut func = PiecewiseLinearFn::with_merge_threshold(3);
168-
// TODO: make test
205+
let lines = [(0, 3), (1, 0), (-1, 8), (2, -1), (-1, 4)];
206+
let xs = [0, 1, 2, 3, 4, 5];
207+
// results[i] consists of the expected y-coordinates after processing
208+
// the first i+1 lines.
209+
let results = [
210+
[3, 3, 3, 3, 3, 3],
211+
[0, 1, 2, 3, 3, 3],
212+
[0, 1, 2, 3, 3, 3],
213+
[-1, 1, 2, 3, 3, 3],
214+
[-1, 1, 2, 1, 0, -1],
215+
];
216+
for threshold in 0..=lines.len() {
217+
let mut func = PiecewiseLinearFn::with_merge_threshold(threshold);
218+
assert_eq!(func.evaluate(0.0), f64::MAX);
219+
for (&(slope, intercept), expected) in lines.iter().zip(results.iter()) {
220+
func.min_with(slope as f64, intercept as f64);
221+
let ys: Vec<i64> = xs.iter().map(|&x| func.evaluate(x as f64) as i64).collect();
222+
assert_eq!(expected, &ys[..]);
223+
}
224+
}
169225
}
170226
}

0 commit comments

Comments
 (0)