@@ -103,9 +103,11 @@ impl MoState for DistinctVals {
103
103
104
104
/// Represents a minimum (lower envelope) of a collection of linear functions of a variable,
105
105
/// evaluated using the convex hull trick with square root decomposition.
106
+ #[ derive( Debug ) ]
106
107
pub struct PiecewiseLinearFn {
107
- sorted_lines : Vec < ( i64 , i64 ) > ,
108
- recent_lines : Vec < ( i64 , i64 ) > ,
108
+ sorted_lines : Vec < ( f64 , f64 ) > ,
109
+ intersections : Vec < f64 > ,
110
+ recent_lines : Vec < ( f64 , f64 ) > ,
109
111
merge_threshold : usize ,
110
112
}
111
113
@@ -116,31 +118,67 @@ impl PiecewiseLinearFn {
116
118
pub fn with_merge_threshold ( merge_threshold : usize ) -> Self {
117
119
Self {
118
120
sorted_lines : vec ! [ ] ,
121
+ intersections : vec ! [ ] ,
119
122
recent_lines : vec ! [ ] ,
120
123
merge_threshold,
121
124
}
122
125
}
123
126
124
127
/// Replaces this function with the minimum of itself and a provided line
125
- pub fn min_with ( & mut self , slope : i64 , intercept : i64 ) {
128
+ pub fn min_with ( & mut self , slope : f64 , intercept : f64 ) {
126
129
self . recent_lines . push ( ( slope, intercept) ) ;
127
130
}
128
131
129
132
fn update_envelope ( & mut self ) {
130
133
self . recent_lines . extend ( self . sorted_lines . drain ( ..) ) ;
131
- self . recent_lines . sort_unstable ( ) ;
132
- for ( slope, intercept) in self . recent_lines . drain ( ..) {
133
- // TODO: do convex hull trick algorithm
134
- self . sorted_lines . push ( ( slope, intercept) ) ;
134
+ self . recent_lines
135
+ . sort_unstable_by ( |x, y| y. partial_cmp ( & x) . unwrap ( ) ) ;
136
+ self . intersections . clear ( ) ;
137
+
138
+ for ( m1, b1) in self . recent_lines . drain ( ..) {
139
+ while let Some ( & ( m2, b2) ) = self . sorted_lines . last ( ) {
140
+ // If slopes are equal, the later line will always have lower
141
+ // intercept, so we can get rid of the old one.
142
+ if ( m1 - m2) . abs ( ) > 1e-10f64 {
143
+ let new_intersection = ( b1 - b2) / ( m2 - m1) ;
144
+ if & new_intersection > self . intersections . last ( ) . unwrap_or ( & f64:: MIN ) {
145
+ self . intersections . push ( new_intersection) ;
146
+ break ;
147
+ }
148
+ }
149
+ self . intersections . pop ( ) ;
150
+ self . sorted_lines . pop ( ) ;
151
+ }
152
+ self . sorted_lines . push ( ( m1, b1) ) ;
135
153
}
136
154
}
137
155
138
- fn eval_helper ( & self , x : i64 ) -> i64 {
139
- 0 // TODO: pick actual minimum, or infinity if empty
156
+ fn eval_in_envelope ( & self , x : f64 ) -> f64 {
157
+ if self . sorted_lines . is_empty ( ) {
158
+ return f64:: MAX ;
159
+ }
160
+ let idx = match self
161
+ . intersections
162
+ . binary_search_by ( |y| y. partial_cmp ( & x) . unwrap ( ) )
163
+ {
164
+ Ok ( k) => k,
165
+ Err ( k) => k,
166
+ } ;
167
+ let ( m, b) = self . sorted_lines [ idx] ;
168
+ m * x + b
169
+ }
170
+
171
+ fn eval_helper ( & self , x : f64 ) -> f64 {
172
+ self . recent_lines
173
+ . iter ( )
174
+ . map ( |& ( m, b) | m * x + b)
175
+ . min_by ( |x, y| x. partial_cmp ( y) . unwrap ( ) )
176
+ . unwrap_or ( f64:: MAX )
177
+ . min ( self . eval_in_envelope ( x) )
140
178
}
141
179
142
180
/// Evaluates the function at x
143
- pub fn evaluate ( & mut self , x : i64 ) -> i64 {
181
+ pub fn evaluate ( & mut self , x : f64 ) -> f64 {
144
182
if self . recent_lines . len ( ) > self . merge_threshold {
145
183
self . update_envelope ( ) ;
146
184
}
@@ -164,7 +202,25 @@ mod test {
164
202
165
203
#[ test]
166
204
fn test_convex_hull_trick ( ) {
167
- let mut func = PiecewiseLinearFn :: with_merge_threshold ( 3 ) ;
168
- // TODO: make test
205
+ let lines = [ ( 0 , 3 ) , ( 1 , 0 ) , ( -1 , 8 ) , ( 2 , -1 ) , ( -1 , 4 ) ] ;
206
+ let xs = [ 0 , 1 , 2 , 3 , 4 , 5 ] ;
207
+ // results[i] consists of the expected y-coordinates after processing
208
+ // the first i+1 lines.
209
+ let results = [
210
+ [ 3 , 3 , 3 , 3 , 3 , 3 ] ,
211
+ [ 0 , 1 , 2 , 3 , 3 , 3 ] ,
212
+ [ 0 , 1 , 2 , 3 , 3 , 3 ] ,
213
+ [ -1 , 1 , 2 , 3 , 3 , 3 ] ,
214
+ [ -1 , 1 , 2 , 1 , 0 , -1 ] ,
215
+ ] ;
216
+ for threshold in 0 ..=lines. len ( ) {
217
+ let mut func = PiecewiseLinearFn :: with_merge_threshold ( threshold) ;
218
+ assert_eq ! ( func. evaluate( 0.0 ) , f64 :: MAX ) ;
219
+ for ( & ( slope, intercept) , expected) in lines. iter ( ) . zip ( results. iter ( ) ) {
220
+ func. min_with ( slope as f64 , intercept as f64 ) ;
221
+ let ys: Vec < i64 > = xs. iter ( ) . map ( |& x| func. evaluate ( x as f64 ) as i64 ) . collect ( ) ;
222
+ assert_eq ! ( expected, & ys[ ..] ) ;
223
+ }
224
+ }
169
225
}
170
226
}
0 commit comments