@@ -10,6 +10,7 @@ use log::{debug, error, trace, warn};
1010use std:: {
1111 cmp:: max, io, iter:: FromIterator , sync:: Arc , thread:: sleep, time:: Duration ,
1212} ;
13+ use strum_macros:: EnumString ;
1314use tokio:: {
1415 prelude:: * ,
1516 sync:: mpsc:: { self , Receiver , Sender } ,
@@ -32,6 +33,19 @@ const CONCURRENCY: usize = 48;
3233/// The number of addresses to pass to SmartyStreets at one time.
3334const GEOCODE_SIZE : usize = 72 ;
3435
36+ /// What should we do if a geocoding output column has the same as a column in
37+ /// the input?
38+ #[ derive( Debug , Clone , Copy , EnumString , Eq , PartialEq ) ]
39+ #[ strum( serialize_all = "snake_case" ) ]
40+ pub enum OnDuplicateColumns {
41+ /// Fail with an error.
42+ Error ,
43+ /// Replace existing columns with the same name.
44+ Replace ,
45+ /// Leave the old columns in place and append the new ones.
46+ Append ,
47+ }
48+
3549/// Data about the CSV file that we include with every chunk to be geocoded.
3650struct Shared {
3751 /// Which columns contain addresses that we need to geocode?
@@ -65,7 +79,7 @@ enum Message {
6579pub async fn geocode_stdio (
6680 spec : AddressColumnSpec < String > ,
6781 match_strategy : MatchStrategy ,
68- replace_existing_columns : bool ,
82+ on_duplicate_columns : OnDuplicateColumns ,
6983 structure : Structure ,
7084) -> Result < ( ) > {
7185 // Set up bounded channels for communication between the sync and async
@@ -76,7 +90,7 @@ pub async fn geocode_stdio(
7690 // Hook up our inputs and outputs, which are synchronous functions running
7791 // in their own threads.
7892 let read_fut = run_sync_fn_in_background ( "read CSV" . to_owned ( ) , move || {
79- read_csv_from_stdin ( spec, structure, replace_existing_columns , in_tx)
93+ read_csv_from_stdin ( spec, structure, on_duplicate_columns , in_tx)
8094 } ) ;
8195 let write_fut = run_sync_fn_in_background ( "write CSV" . to_owned ( ) , move || {
8296 write_csv_to_stdout ( out_rx)
@@ -152,7 +166,7 @@ pub async fn geocode_stdio(
152166fn read_csv_from_stdin (
153167 spec : AddressColumnSpec < String > ,
154168 structure : Structure ,
155- replace_existing_columns : bool ,
169+ on_duplicate_columns : OnDuplicateColumns ,
156170 mut tx : Sender < Message > ,
157171) -> Result < ( ) > {
158172 // Open up our CSV file and get the headers.
@@ -161,47 +175,56 @@ fn read_csv_from_stdin(
161175 let mut in_headers = rdr. headers ( ) ?. to_owned ( ) ;
162176 debug ! ( "input headers: {:?}" , in_headers) ;
163177
164- // Look for duplicate input columns, and decide what to do.
165- let column_indices_to_remove =
166- spec. column_indices_to_remove ( & structure, & in_headers) ?;
167- let remove_column_flags = if column_indices_to_remove. is_empty ( ) {
168- // No columns to remove!
169- None
170- } else {
171- // We have to remove columns, so make a human-readable list...
172- let mut removed_names = vec ! [ ] ;
173- for i in & column_indices_to_remove {
174- removed_names. push ( & in_headers[ * i] ) ;
175- }
176- let removed_names = removed_names. join ( " " ) ;
177-
178- // And decide whether we're actually allowed to remove them or not.
179- if replace_existing_columns {
180- warn ! ( "removing input columns: {}" , removed_names) ;
178+ // Figure out if we have any duplicate columns.
179+ let ( duplicate_column_indices, duplicate_column_names) = {
180+ let duplicate_columns = spec. duplicate_columns ( & structure, & in_headers) ?;
181+ let indices = duplicate_columns
182+ . iter ( )
183+ . map ( |name_idx| name_idx. 1 )
184+ . collect :: < Vec < _ > > ( ) ;
185+ let names = duplicate_columns
186+ . iter ( )
187+ . map ( |name_idx| name_idx. 0 )
188+ . collect :: < Vec < _ > > ( )
189+ . join ( ", " ) ;
190+ ( indices, names)
191+ } ;
181192
182- // Build the vector of bools specifying whether columns should
183- // stay or go.
184- let mut flags = vec ! [ false ; in_headers. len( ) ] ;
185- for i in column_indices_to_remove {
186- flags[ i] = true ;
193+ // If we do have duplicate columns, figure out what to do about it.
194+ let mut should_remove_columns = false ;
195+ let mut remove_column_flags = vec ! [ false ; in_headers. len( ) ] ;
196+ if !duplicate_column_indices. is_empty ( ) {
197+ match on_duplicate_columns {
198+ OnDuplicateColumns :: Error => {
199+ return Err ( format_err ! (
200+ "input columns would conflict with geocoding columns: {}" ,
201+ duplicate_column_names,
202+ ) ) ;
203+ }
204+ OnDuplicateColumns :: Replace => {
205+ warn ! ( "replacing input columns: {}" , duplicate_column_names) ;
206+ should_remove_columns = true ;
207+ for i in duplicate_column_indices. iter ( ) . cloned ( ) {
208+ remove_column_flags[ i] = true ;
209+ }
210+ }
211+ OnDuplicateColumns :: Append => {
212+ warn ! (
213+ "output contains duplicate columns: {}" ,
214+ duplicate_column_names,
215+ ) ;
187216 }
188- Some ( flags)
189- } else {
190- return Err ( format_err ! (
191- "input columns would conflict with geocoding columns: {}" ,
192- removed_names,
193- ) ) ;
194217 }
195- } ;
218+ }
196219
197220 // Remove any duplicate columns from our input headers.
198- if let Some ( remove_column_flags ) = & remove_column_flags {
221+ if should_remove_columns {
199222 in_headers = remove_columns ( & in_headers, & remove_column_flags) ;
200223 }
201224
202225 // Convert our column spec from using header names to header indices.
203226 //
204- // This needs to use "post-removal" indices !
227+ // This needs to happen _after_ `remove_columns` on our headers !
205228 let spec = spec. convert_to_indices_using_headers ( & in_headers) ?;
206229
207230 // Decide how big to make our chunks. We want to geocode no more
@@ -228,9 +251,9 @@ fn read_csv_from_stdin(
228251 let mut rows = Vec :: with_capacity ( chunk_size) ;
229252 for row in rdr. records ( ) {
230253 let mut row = row?;
231- if let Some ( remove_column_flags ) = & remove_column_flags {
254+ if should_remove_columns {
232255 // Strip out any duplicate columns.
233- row = remove_columns ( & row, remove_column_flags) ;
256+ row = remove_columns ( & row, & remove_column_flags) ;
234257 }
235258 rows. push ( row) ;
236259 if rows. len ( ) >= chunk_size {
@@ -270,6 +293,20 @@ fn read_csv_from_stdin(
270293 Ok ( ( ) )
271294}
272295
296+ /// Remove columns from `row` if they're set to true in `remove_column_flags`.
297+ fn remove_columns ( row : & StringRecord , remove_column_flags : & [ bool ] ) -> StringRecord {
298+ debug_assert_eq ! ( row. len( ) , remove_column_flags. len( ) ) ;
299+ StringRecord :: from_iter ( row. iter ( ) . zip ( remove_column_flags) . filter_map (
300+ |( value, & remove) | {
301+ if remove {
302+ None
303+ } else {
304+ Some ( value. to_owned ( ) )
305+ }
306+ } ,
307+ ) )
308+ }
309+
273310/// Receive chunks of a CSV file from `rx` and write them to standard output.
274311fn write_csv_to_stdout ( mut rx : Receiver < Message > ) -> Result < ( ) > {
275312 let stdout = io:: stdout ( ) ;
@@ -410,15 +447,3 @@ async fn geocode_chunk(
410447 }
411448 Ok ( chunk)
412449}
413-
414- fn remove_columns ( row : & StringRecord , remove_column_flags : & [ bool ] ) -> StringRecord {
415- StringRecord :: from_iter ( row. iter ( ) . zip ( remove_column_flags) . filter_map (
416- |( value, & remove) | {
417- if remove {
418- None
419- } else {
420- Some ( value. to_owned ( ) )
421- }
422- } ,
423- ) )
424- }
0 commit comments