diff --git a/Cargo.lock b/Cargo.lock index 517b5aa..3f396a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -44,6 +44,22 @@ name = "bitflags" version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "bstr" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "lazy_static 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)", + "memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)", + "regex-automata 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", + "serde 1.0.101 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "byteorder" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "c2-chacha" version = "0.2.2" @@ -159,6 +175,26 @@ dependencies = [ "lazy_static 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "csv" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "bstr 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)", + "csv-core 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)", + "itoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)", + "ryu 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", + "serde 1.0.101 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "csv-core" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "either" version = "1.5.0" @@ -235,6 +271,11 @@ dependencies = [ "regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "itoa" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "lazy_static" version = "1.2.0" @@ -257,6 +298,9 @@ dependencies = [ name = "memchr" version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "libc 0.2.58 (registry+https://github.com/rust-lang/crates.io-index)", +] [[package]] name = "memoffset" @@ -425,16 +469,34 @@ dependencies = [ "thread_local 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "regex-automata" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "regex-syntax" version = "0.6.12" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "ryu" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "scopeguard" version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "serde" +version = "1.0.101" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "strsim" version = "0.8.0" @@ -526,6 +588,7 @@ name = "ttv" version = "0.2.2" dependencies = [ "clap 2.33.0 (registry+https://github.com/rust-lang/crates.io-index)", + "csv 1.1.1 (registry+https://github.com/rust-lang/crates.io-index)", "env_logger 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", "flate2 1.0.11 (registry+https://github.com/rust-lang/crates.io-index)", "indicatif 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)", @@ -606,6 +669,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum arrayvec 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)" = "a1e964f9e24d588183fcb43503abda40d288c8657dfc27311516ce2f05675aef" "checksum atty 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)" = "9a7d5b8723950951411ee34d271d99dddcc2035a16ab25310ea2c8cfd4369652" "checksum bitflags 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)" = "228047a76f468627ca71776ecdebd732a3423081fcf5125585bcd7c49886ce12" +"checksum bstr 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)" = "8d6c2c5b58ab920a4f5aeaaca34b4488074e8cc7596af94e6f8c6ff247c60245" +"checksum byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)" = "a7c3dd8985a7111efc5c80b44e23ecdd8c007de8ade3b96595387e812b957cf5" "checksum c2-chacha 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7d64d04786e0f528460fc884753cf8dddcc466be308f6026f8e355c41a0e4101" "checksum cc 1.0.25 (registry+https://github.com/rust-lang/crates.io-index)" = "f159dfd43363c4d08055a07703eb7a3406b0dac4d0584d96965a3262db3c9d16" "checksum cfg-if 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "082bb9b28e00d3c9d39cc03e64ce4cea0f1bb9b3fde493f0cbc008472d22bdf4" @@ -618,6 +683,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum crossbeam-epoch 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "04c9e3102cc2d69cd681412141b390abd55a362afc1540965dad0ad4d34280b4" "checksum crossbeam-queue 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7c979cd6cfe72335896575c6b5688da489e420d36a27a0b9eb0c73db574b4a4b" "checksum crossbeam-utils 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)" = "f8306fcef4a7b563b76b7dd949ca48f52bc1141aa067d2ea09565f3e2652aa5c" +"checksum csv 1.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "37519ccdfd73a75821cac9319d4fce15a81b9fcf75f951df5b9988aa3a0af87d" +"checksum csv-core 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "9b5cadb6b25c77aeff80ba701712494213f4a8418fcda2ee11b6560c3ad0bf4c" "checksum either 1.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3be565ca5c557d7f59e7cfcf1844f9e3033650c929c6566f511e8005f205c1d0" "checksum encode_unicode 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)" = "90b2c9496c001e8cb61827acdefad780795c42264c137744cae6f7d9e3450abd" "checksum env_logger 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "39ecdb7dd54465526f0a56d666e3b2dd5f3a218665a030b6e4ad9e70fa95d8fa" @@ -627,6 +694,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum heck 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ea04fa3ead4e05e51a7c806fc07271fdbde4e246a6c6d1efd52e72230b771b82" "checksum humantime 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f" "checksum indicatif 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "a8d596a9576eaa1446996092642d72bfef35cf47243129b7ab883baf5faec31e" +"checksum itoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)" = "501266b7edd0174f8530248f87f99c88fbe60ca4ef3dd486835b8d8d53136f7f" "checksum lazy_static 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "a374c89b9db55895453a74c1e38861d9deec0b01b405a82516e9d5de4820dea1" "checksum libc 0.2.58 (registry+https://github.com/rust-lang/crates.io-index)" = "6281b86796ba5e4366000be6e9e18bf35580adf9e63fbe2294aadb587613a319" "checksum log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)" = "14b6052be84e6b71ab17edffc2eeabf5c2c3ae1fdb464aae35ac50c67a44e1f7" @@ -652,8 +720,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum redox_syscall 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)" = "cf8fb82a4d1c9b28f1c26c574a5b541f5ffb4315f6c9a791fa47b6a04438fe93" "checksum redox_termios 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "7e891cfe48e9100a70a3b6eb652fef28920c117d366339687bd5576160db0f76" "checksum regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "dc220bd33bdce8f093101afe22a037b8eb0e5af33592e6a9caafff0d4cb81cbd" +"checksum regex-automata 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "92b73c2a1770c255c240eaa4ee600df1704a38dc3feaa6e949e7fcd4f8dc09f9" "checksum regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)" = "11a7e20d1cce64ef2fed88b66d347f88bd9babb82845b2b858f3edbf59a4f716" +"checksum ryu 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c92464b447c0ee8c4fb3824ecc8383b81717b9f1e74ba2e72540aef7b9f82997" "checksum scopeguard 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "94258f53601af11e6a49f722422f6e3425c52b06245a5cf9bc09908b174f5e27" +"checksum serde 1.0.101 (registry+https://github.com/rust-lang/crates.io-index)" = "9796c9b7ba2ffe7a9ce53c2287dfc48080f4b2b362fcc245a259b3a7201119dd" "checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" "checksum structopt 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3fe8d3289b63ef2f196d89e7701f986583c0895e764b78f052a55b9b5d34d84a" "checksum structopt-derive 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f3add731f5b4fb85931d362a3c92deb1ad7113649a8d51701fb257673705f122" diff --git a/Cargo.toml b/Cargo.toml index 3c50d53..1b39ea9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2018" [dependencies] clap = { version = "2.33.0", features = ["yaml"] } +csv = "1.0.0" env_logger = "0.7.0" flate2 = "1.0.11" indicatif = "0.12.0" @@ -14,5 +15,5 @@ pad = "0.1.5" rand = "0.7.2" rand_chacha = "0.2.1" rayon = "1.2.0" -structopt = "0.3.2" +structopt = "0.3.0" try_from = "0.3.2" diff --git a/src/cli.rs b/src/cli.rs index 89b1cae..bac11bc 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -35,22 +35,22 @@ pub struct Split { #[structopt( short = "r", long = "rows", - required_unless = "prop_splits", - conflicts_with = "prop_splits", + required_unless = "prop", + conflicts_with = "prop", help = "Specify splits by number of rows", use_delimiter = true )] - pub row_splits: Vec, + pub rows: Vec, #[structopt( short = "p", long = "prop", - required_unless = "row_splits", - conflicts_with = "row_splits", + required_unless = "rows", + conflicts_with = "rows", help = "Specify splits by proportion of rows", use_delimiter = true )] - pub prop_splits: Vec, + pub prop: Vec, #[structopt( short = "c", @@ -69,6 +69,12 @@ pub struct Split { #[structopt(short = "s", long = "seed", help = "RNG seed, for reproducibility")] pub seed: Option, + #[structopt( + long = "csv", + help = "Parse input as CSV. Only needed if rows contain embedded newlines - will impact performance." + )] + pub csv: bool, + #[structopt( parse(from_os_str), help = "Data to split, optionally gzip compressed. If '-', read from stdin" diff --git a/src/error.rs b/src/error.rs index 84bf55d..086e822 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,6 +12,7 @@ pub enum Error { ProportionTooLow(String), ProportionTooHigh(String), + CsvError(csv::Error), IoError(std::io::Error), ParseFloatError(std::num::ParseFloatError), ParseIntError(std::num::ParseIntError), @@ -38,6 +39,12 @@ impl From for Error { } } +impl From for Error { + fn from(error: csv::Error) -> Self { + Error::CsvError(error) + } +} + impl From> for Error { fn from(error: std::sync::mpsc::SendError) -> Self { Error::SendError(error) diff --git a/src/io.rs b/src/io.rs index 02a5ef2..7e14afa 100644 --- a/src/io.rs +++ b/src/io.rs @@ -7,7 +7,6 @@ use flate2::write::GzEncoder; use crate::error::Result; -pub type InputReader = BufReader>; pub type OutputWriter = Box; #[derive(Clone, Copy, Debug)] @@ -16,7 +15,43 @@ pub enum Compression { GzipCompression, } -pub fn open_data>(path: P, compression: Compression) -> Result { +pub trait LineReader { + fn read_line(&mut self) -> Option>; +} + +impl LineReader for csv::Reader> { + fn read_line(&mut self) -> Option> { + let mut record = csv::ByteRecord::with_capacity(1024, 100); + match self.read_byte_record(&mut record) { + Ok(read) if read => { + let curs = std::io::Cursor::new(Vec::with_capacity(1024)); + let mut writer = csv::Writer::from_writer(curs); + writer.write_byte_record(&record).unwrap(); + let s = String::from_utf8(writer.into_inner().unwrap().into_inner()).unwrap(); + Some(Ok(s)) + } + Ok(_) => None, + Err(e) => Some(Err(e.into())), + } + } +} + +impl LineReader for BufReader> { + fn read_line(&mut self) -> Option> { + let mut buf = String::with_capacity(1024); + match std::io::BufRead::read_line(self, &mut buf) { + Ok(n) if n == 0 => None, + Ok(_) => Some(Ok(buf)), + Err(e) => Some(Err(e.into())), + } + } +} + +pub fn open_data>( + path: P, + compression: Compression, + csv_builder: Option, +) -> Result> { // Read from stdin if input is '-', else try to open the provided file. let reader: Box = match path.as_ref().to_str() { Some(p) if p == "-" => Box::new(std::io::stdin()), @@ -28,7 +63,12 @@ pub fn open_data>(path: P, compression: Compression) -> Result reader, Compression::GzipCompression => Box::new(GzDecoder::new(reader)), }; - Ok(BufReader::with_capacity(1024 * 1024, reader)) + + let reader: Box = match csv_builder { + Some(builder) => Box::new(builder.from_reader(reader)), + None => Box::new(BufReader::with_capacity(1024 * 1024, reader)), + }; + Ok(reader) } pub fn open_output>(path: P, compression: Compression) -> Result { diff --git a/src/main.rs b/src/main.rs index d041e4c..b1acdda 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,13 +8,16 @@ fn main() -> Result<()> { let opt = cli::Opt::from_args(); match opt.cmd { cli::Command::Split(x) => { - let mut splitter = SplitterBuilder::new(&x.input, x.row_splits, x.prop_splits)?; + let mut splitter = SplitterBuilder::new(&x.input, x.rows, x.prop)?; if x.uncompressed { splitter = splitter.input_compression(Compression::Uncompressed); } if x.uncompressed_output { splitter = splitter.output_compression(Compression::Uncompressed); } + if x.csv { + splitter = splitter.csv(true); + } if let Some(seed) = x.seed { splitter = splitter.seed(seed); } diff --git a/src/split/splitter.rs b/src/split/splitter.rs index f41331c..1d3076c 100644 --- a/src/split/splitter.rs +++ b/src/split/splitter.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::io::BufRead; use std::path::{Path, PathBuf}; use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; @@ -33,6 +32,8 @@ pub struct SplitterBuilder { input_compression: Compression, /// Compression for output files output_compression: Compression, + /// Is the input CSV? + csv: bool, } impl SplitterBuilder { @@ -55,6 +56,7 @@ impl SplitterBuilder { total_rows: None, input_compression: Compression::GzipCompression, output_compression: Compression::GzipCompression, + csv: false, }) } @@ -88,6 +90,11 @@ impl SplitterBuilder { self } + pub fn csv(mut self, csv: bool) -> Self { + self.csv = csv; + self + } + pub fn build(self) -> Result { let rng = match self.seed { Some(s) => ChaChaRng::seed_from_u64(s), @@ -102,6 +109,7 @@ impl SplitterBuilder { total_rows: self.total_rows, input_compression: self.input_compression, output_compression: self.output_compression, + csv: self.csv, }) } } @@ -123,6 +131,8 @@ pub struct Splitter { input_compression: Compression, /// Compression for output files output_compression: Compression, + /// Is the input CSV? + csv: bool, } impl Splitter { @@ -222,13 +232,19 @@ impl Splitter { pool.scope(move |scope| { info!("Reading data from {}", self.input.to_str().unwrap()); - let reader = open_data(&self.input, self.input_compression)?; + let reader_builder = if self.csv { + let mut reader_builder = csv::ReaderBuilder::new(); + reader_builder.has_headers(false); + Some(reader_builder) + } else { + None + }; + let mut reader = open_data(&self.input, self.input_compression, reader_builder)?; info!("Writing header to files"); - let mut lines = reader.lines(); - let header = match lines.next() { + let header = match reader.read_line() { + Some(h) => h?, None => return Err(Error::EmptyFile), - Some(res) => res?, }; for sender in senders.values_mut() { sender.send_all(&header)?; @@ -252,7 +268,8 @@ impl Splitter { header = Some(row.clone()); } if let Some(chunk_size) = writer.chunk_size { - if rows_sent_to_chunk > (chunk_size + 1) { // add one for header + if rows_sent_to_chunk > (chunk_size + 1) { + // add one for header // This should only ever happen if we weren't // able to pre-calculate how many chunks were // needed @@ -274,7 +291,7 @@ impl Splitter { } info!("Reading lines"); - for record in lines { + while let Some(record) = reader.read_line() { let split = self.splits.get_split(&mut self.rng); match split { SplitSelection::Some(split) => { diff --git a/src/split/writer.rs b/src/split/writer.rs index acee6a9..53501cb 100644 --- a/src/split/writer.rs +++ b/src/split/writer.rs @@ -174,7 +174,6 @@ impl ChunkWriter { /// Handle writing of a row to this chunk. pub fn handle_row(&self, file: &mut io::OutputWriter, row: &str) -> Result<()> { file.write_all(row.as_bytes())?; - file.write_all(b"\n")?; Ok(()) } }