Skip to content

Commit

Permalink
Merge pull request #37 from sd2k/16-parse-rows-as-csv
Browse files Browse the repository at this point in the history
Add flag to parse files as CSV
  • Loading branch information
sd2k authored Sep 24, 2019
2 parents cb481ed + 34f1b85 commit 1efc115
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 19 deletions.
71 changes: 71 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ edition = "2018"

[dependencies]
clap = { version = "2.33.0", features = ["yaml"] }
csv = "1.0.0"
env_logger = "0.7.0"
flate2 = "1.0.11"
indicatif = "0.12.0"
Expand All @@ -14,5 +15,5 @@ pad = "0.1.5"
rand = "0.7.2"
rand_chacha = "0.2.1"
rayon = "1.2.0"
structopt = "0.3.2"
structopt = "0.3.0"
try_from = "0.3.2"
18 changes: 12 additions & 6 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,22 @@ pub struct Split {
#[structopt(
short = "r",
long = "rows",
required_unless = "prop_splits",
conflicts_with = "prop_splits",
required_unless = "prop",
conflicts_with = "prop",
help = "Specify splits by number of rows",
use_delimiter = true
)]
pub row_splits: Vec<RowSplit>,
pub rows: Vec<RowSplit>,

#[structopt(
short = "p",
long = "prop",
required_unless = "row_splits",
conflicts_with = "row_splits",
required_unless = "rows",
conflicts_with = "rows",
help = "Specify splits by proportion of rows",
use_delimiter = true
)]
pub prop_splits: Vec<ProportionSplit>,
pub prop: Vec<ProportionSplit>,

#[structopt(
short = "c",
Expand All @@ -69,6 +69,12 @@ pub struct Split {
#[structopt(short = "s", long = "seed", help = "RNG seed, for reproducibility")]
pub seed: Option<u64>,

#[structopt(
long = "csv",
help = "Parse input as CSV. Only needed if rows contain embedded newlines - will impact performance."
)]
pub csv: bool,

#[structopt(
parse(from_os_str),
help = "Data to split, optionally gzip compressed. If '-', read from stdin"
Expand Down
7 changes: 7 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub enum Error {
ProportionTooLow(String),
ProportionTooHigh(String),

CsvError(csv::Error),
IoError(std::io::Error),
ParseFloatError(std::num::ParseFloatError),
ParseIntError(std::num::ParseIntError),
Expand All @@ -38,6 +39,12 @@ impl From<std::io::Error> for Error {
}
}

impl From<csv::Error> for Error {
fn from(error: csv::Error) -> Self {
Error::CsvError(error)
}
}

impl From<std::sync::mpsc::SendError<String>> for Error {
fn from(error: std::sync::mpsc::SendError<String>) -> Self {
Error::SendError(error)
Expand Down
46 changes: 43 additions & 3 deletions src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use flate2::write::GzEncoder;

use crate::error::Result;

pub type InputReader = BufReader<Box<dyn Read>>;
pub type OutputWriter = Box<dyn Write>;

#[derive(Clone, Copy, Debug)]
Expand All @@ -16,7 +15,43 @@ pub enum Compression {
GzipCompression,
}

pub fn open_data<P: AsRef<Path>>(path: P, compression: Compression) -> Result<InputReader> {
/// A source of input that is consumed one logical line (or record) at a time.
pub trait LineReader {
    /// Reads the next line from the underlying input.
    ///
    /// The implementations in this file return `None` once the input is
    /// exhausted, `Some(Ok(line))` for each line or record that was read, and
    /// `Some(Err(..))` when reading fails.
    fn read_line(&mut self) -> Option<Result<String>>;
}

impl LineReader for csv::Reader<Box<dyn Read>> {
fn read_line(&mut self) -> Option<Result<String>> {
let mut record = csv::ByteRecord::with_capacity(1024, 100);
match self.read_byte_record(&mut record) {
Ok(read) if read => {
let curs = std::io::Cursor::new(Vec::with_capacity(1024));
let mut writer = csv::Writer::from_writer(curs);
writer.write_byte_record(&record).unwrap();
let s = String::from_utf8(writer.into_inner().unwrap().into_inner()).unwrap();
Some(Ok(s))
}
Ok(_) => None,
Err(e) => Some(Err(e.into())),
}
}
}

impl LineReader for BufReader<Box<dyn Read>> {
    /// Reads the next newline-delimited line from the buffered input.
    ///
    /// Returns `None` when zero bytes were read (end of input), the line as
    /// filled in by `BufRead::read_line` otherwise, and `Some(Err(..))` when the
    /// underlying read fails.
    fn read_line(&mut self) -> Option<Result<String>> {
        let mut line = String::with_capacity(1024);
        // Call the std trait method by path: this trait's own `read_line`
        // shares the name.
        let bytes_read = std::io::BufRead::read_line(self, &mut line);
        match bytes_read {
            Err(err) => Some(Err(err.into())),
            Ok(0) => None,
            Ok(_) => Some(Ok(line)),
        }
    }
}

pub fn open_data<P: AsRef<Path>>(
path: P,
compression: Compression,
csv_builder: Option<csv::ReaderBuilder>,
) -> Result<Box<dyn LineReader>> {
// Read from stdin if input is '-', else try to open the provided file.
let reader: Box<dyn Read> = match path.as_ref().to_str() {
Some(p) if p == "-" => Box::new(std::io::stdin()),
Expand All @@ -28,7 +63,12 @@ pub fn open_data<P: AsRef<Path>>(path: P, compression: Compression) -> Result<In
Compression::Uncompressed => reader,
Compression::GzipCompression => Box::new(GzDecoder::new(reader)),
};
Ok(BufReader::with_capacity(1024 * 1024, reader))

let reader: Box<dyn LineReader> = match csv_builder {
Some(builder) => Box::new(builder.from_reader(reader)),
None => Box::new(BufReader::with_capacity(1024 * 1024, reader)),
};
Ok(reader)
}

pub fn open_output<P: AsRef<Path>>(path: P, compression: Compression) -> Result<OutputWriter> {
Expand Down
5 changes: 4 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ fn main() -> Result<()> {
let opt = cli::Opt::from_args();
match opt.cmd {
cli::Command::Split(x) => {
let mut splitter = SplitterBuilder::new(&x.input, x.row_splits, x.prop_splits)?;
let mut splitter = SplitterBuilder::new(&x.input, x.rows, x.prop)?;
if x.uncompressed {
splitter = splitter.input_compression(Compression::Uncompressed);
}
if x.uncompressed_output {
splitter = splitter.output_compression(Compression::Uncompressed);
}
if x.csv {
splitter = splitter.csv(true);
}
if let Some(seed) = x.seed {
splitter = splitter.seed(seed);
}
Expand Down
Loading

0 comments on commit 1efc115

Please sign in to comment.