Skip to content

Commit 40df55f

Browse files
authored
use clap 3 style args parsing for datafusion cli (#1749)
* use clap 3 style args parsing for datafusion cli * upgrade cli version
1 parent 15cfcbc commit 40df55f

File tree

7 files changed

+80
-179
lines changed

7 files changed

+80
-179
lines changed

datafusion-cli/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
[package]
1919
name = "datafusion-cli"
20-
version = "5.1.0"
20+
description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model. It supports executing SQL queries against CSV and Parquet files as well as querying directly against in-memory data."
21+
version = "6.0.0"
2122
authors = ["Apache Arrow <dev@arrow.apache.org>"]
2223
edition = "2021"
2324
keywords = [ "arrow", "datafusion", "ballista", "query", "sql" ]

datafusion-cli/src/command.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
use crate::context::Context;
2121
use crate::functions::{display_all_functions, Function};
2222
use crate::print_format::PrintFormat;
23-
use crate::print_options::{self, PrintOptions};
23+
use crate::print_options::PrintOptions;
24+
use clap::ArgEnum;
2425
use datafusion::arrow::array::{ArrayRef, StringArray};
2526
use datafusion::arrow::datatypes::{DataType, Field, Schema};
2627
use datafusion::arrow::record_batch::RecordBatch;
@@ -206,10 +207,14 @@ impl OutputFormat {
206207
Self::ChangeFormat(format) => {
207208
if let Ok(format) = format.parse::<PrintFormat>() {
208209
print_options.format = format;
209-
println!("Output format is {}.", print_options.format);
210+
println!("Output format is {:?}.", print_options.format);
210211
Ok(())
211212
} else {
212-
Err(DataFusionError::Execution(format!("{} is not a valid format type [possible values: csv, tsv, table, json, ndjson]", format)))
213+
Err(DataFusionError::Execution(format!(
214+
"{:?} is not a valid format type [possible values: {:?}]",
215+
format,
216+
PrintFormat::value_variants()
217+
)))
213218
}
214219
}
215220
}

datafusion-cli/src/exec.rs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,14 @@ use crate::{
2121
command::{Command, OutputFormat},
2222
context::Context,
2323
helper::CliHelper,
24-
print_format::{all_print_formats, PrintFormat},
2524
print_options::PrintOptions,
2625
};
27-
use datafusion::arrow::record_batch::RecordBatch;
28-
use datafusion::arrow::util::pretty;
29-
use datafusion::error::{DataFusionError, Result};
30-
use rustyline::config::Config;
26+
use datafusion::error::Result;
3127
use rustyline::error::ReadlineError;
3228
use rustyline::Editor;
3329
use std::fs::File;
3430
use std::io::prelude::*;
3531
use std::io::BufReader;
36-
use std::str::FromStr;
37-
use std::sync::Arc;
3832
use std::time::Instant;
3933

4034
/// run and execute SQL statements and commands from a file, against a context with the given print options
@@ -109,7 +103,7 @@ pub async fn exec_from_repl(ctx: &mut Context, print_options: &mut PrintOptions)
109103
);
110104
}
111105
} else {
112-
println!("Output format is {}.", print_options.format);
106+
println!("Output format is {:?}.", print_options.format);
113107
}
114108
}
115109
_ => {

datafusion-cli/src/functions.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use arrow::array::StringArray;
2020
use arrow::datatypes::{DataType, Field, Schema};
2121
use arrow::record_batch::RecordBatch;
2222
use arrow::util::pretty::pretty_format_batches;
23-
use datafusion::error::{DataFusionError, Result};
23+
use datafusion::error::Result;
2424
use std::fmt;
2525
use std::str::FromStr;
2626
use std::sync::Arc;

datafusion-cli/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
// under the License.
1717

1818
#![doc = include_str!("../README.md")]
19-
#![allow(unused_imports)]
2019
pub const DATAFUSION_CLI_VERSION: &str = env!("CARGO_PKG_VERSION");
2120

2221
pub mod command;

datafusion-cli/src/main.rs

Lines changed: 63 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -15,132 +15,96 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use clap::{crate_version, App, Arg};
18+
use clap::Parser;
1919
use datafusion::error::Result;
2020
use datafusion::execution::context::ExecutionConfig;
2121
use datafusion_cli::{
22-
context::Context,
23-
exec,
24-
print_format::{all_print_formats, PrintFormat},
25-
print_options::PrintOptions,
22+
context::Context, exec, print_format::PrintFormat, print_options::PrintOptions,
2623
DATAFUSION_CLI_VERSION,
2724
};
2825
use std::env;
2926
use std::fs::File;
3027
use std::io::BufReader;
3128
use std::path::Path;
3229

30+
#[derive(Debug, Parser, PartialEq)]
31+
#[clap(author, version, about, long_about= None)]
32+
struct Args {
33+
#[clap(
34+
short = 'p',
35+
long,
36+
help = "Path to your data, default to current directory",
37+
validator(is_valid_data_dir)
38+
)]
39+
data_path: Option<String>,
40+
41+
#[clap(
42+
short = 'c',
43+
long,
44+
help = "The batch size of each query, or use DataFusion default",
45+
validator(is_valid_batch_size)
46+
)]
47+
batch_size: Option<usize>,
48+
49+
#[clap(
50+
short,
51+
long,
52+
multiple_values = true,
53+
help = "Execute commands from file(s), then exit",
54+
validator(is_valid_file)
55+
)]
56+
file: Vec<String>,
57+
58+
#[clap(long, arg_enum, default_value_t = PrintFormat::Table)]
59+
format: PrintFormat,
60+
61+
#[clap(long, help = "Ballista scheduler host")]
62+
host: Option<String>,
63+
64+
#[clap(long, help = "Ballista scheduler port")]
65+
port: Option<u16>,
66+
67+
#[clap(
68+
short,
69+
long,
70+
help = "Reduce printing other than the results and work quietly"
71+
)]
72+
quiet: bool,
73+
}
74+
3375
#[tokio::main]
3476
pub async fn main() -> Result<()> {
35-
let matches = App::new("DataFusion")
36-
.version(crate_version!())
37-
.about(
38-
"DataFusion is an in-memory query engine that uses Apache Arrow \
39-
as the memory model. It supports executing SQL queries against CSV and \
40-
Parquet files as well as querying directly against in-memory data.",
41-
)
42-
.arg(
43-
Arg::new("data-path")
44-
.help("Path to your data, default to current directory")
45-
.short('p')
46-
.long("data-path")
47-
.validator(is_valid_data_dir)
48-
.takes_value(true),
49-
)
50-
.arg(
51-
Arg::new("batch-size")
52-
.help("The batch size of each query, or use DataFusion default")
53-
.short('c')
54-
.long("batch-size")
55-
.validator(is_valid_batch_size)
56-
.takes_value(true),
57-
)
58-
.arg(
59-
Arg::new("file")
60-
.help("Execute commands from file(s), then exit")
61-
.short('f')
62-
.long("file")
63-
.multiple_occurrences(true)
64-
.validator(is_valid_file)
65-
.takes_value(true),
66-
)
67-
.arg(
68-
Arg::new("format")
69-
.help("Output format")
70-
.long("format")
71-
.default_value("table")
72-
.possible_values(
73-
&all_print_formats()
74-
.iter()
75-
.map(|format| format.to_string())
76-
.collect::<Vec<_>>()
77-
.iter()
78-
.map(|i| i.as_str())
79-
.collect::<Vec<_>>(),
80-
)
81-
.takes_value(true),
82-
)
83-
.arg(
84-
Arg::new("host")
85-
.help("Ballista scheduler host")
86-
.long("host")
87-
.takes_value(true),
88-
)
89-
.arg(
90-
Arg::new("port")
91-
.help("Ballista scheduler port")
92-
.long("port")
93-
.takes_value(true),
94-
)
95-
.arg(
96-
Arg::new("quiet")
97-
.help("Reduce printing other than the results and work quietly")
98-
.short('q')
99-
.long("quiet")
100-
.takes_value(false),
101-
)
102-
.get_matches();
103-
104-
let quiet = matches.is_present("quiet");
105-
106-
if !quiet {
107-
println!("DataFusion CLI v{}\n", DATAFUSION_CLI_VERSION);
108-
}
77+
let args = Args::parse();
10978

110-
let host = matches.value_of("host");
111-
let port = matches
112-
.value_of("port")
113-
.and_then(|port| port.parse::<u16>().ok());
79+
if !args.quiet {
80+
println!("DataFusion CLI v{}", DATAFUSION_CLI_VERSION);
81+
}
11482

115-
if let Some(path) = matches.value_of("data-path") {
83+
if let Some(ref path) = args.data_path {
11684
let p = Path::new(path);
11785
env::set_current_dir(&p).unwrap();
11886
};
11987

12088
let mut execution_config = ExecutionConfig::new().with_information_schema(true);
12189

122-
if let Some(batch_size) = matches
123-
.value_of("batch-size")
124-
.and_then(|size| size.parse::<usize>().ok())
125-
{
90+
if let Some(batch_size) = args.batch_size {
12691
execution_config = execution_config.with_batch_size(batch_size);
12792
};
12893

129-
let mut ctx: Context = match (host, port) {
130-
(Some(h), Some(p)) => Context::new_remote(h, p)?,
94+
let mut ctx: Context = match (args.host, args.port) {
95+
(Some(ref h), Some(p)) => Context::new_remote(h, p)?,
13196
_ => Context::new_local(&execution_config),
13297
};
13398

134-
let format = matches
135-
.value_of("format")
136-
.expect("No format is specified")
137-
.parse::<PrintFormat>()
138-
.expect("Invalid format");
139-
140-
let mut print_options = PrintOptions { format, quiet };
99+
let mut print_options = PrintOptions {
100+
format: args.format,
101+
quiet: args.quiet,
102+
};
141103

142-
if let Some(file_paths) = matches.values_of("file") {
143-
let files = file_paths
104+
let files = args.file;
105+
if !files.is_empty() {
106+
let files = files
107+
.into_iter()
144108
.map(|file_path| File::open(file_path).unwrap())
145109
.collect::<Vec<_>>();
146110
for file in files {

datafusion-cli/src/print_format.rs

Lines changed: 4 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,10 @@ use arrow::json::{ArrayWriter, LineDelimitedWriter};
2121
use datafusion::arrow::record_batch::RecordBatch;
2222
use datafusion::arrow::util::pretty;
2323
use datafusion::error::{DataFusionError, Result};
24-
use std::fmt;
2524
use std::str::FromStr;
2625

2726
/// Allow records to be printed in different formats
28-
#[derive(Debug, PartialEq, Eq, Clone)]
27+
#[derive(Debug, PartialEq, Eq, clap::ArgEnum, Clone)]
2928
pub enum PrintFormat {
3029
Csv,
3130
Tsv,
@@ -34,40 +33,11 @@ pub enum PrintFormat {
3433
NdJson,
3534
}
3635

37-
/// returns all print formats
38-
pub fn all_print_formats() -> Vec<PrintFormat> {
39-
vec![
40-
PrintFormat::Csv,
41-
PrintFormat::Tsv,
42-
PrintFormat::Table,
43-
PrintFormat::Json,
44-
PrintFormat::NdJson,
45-
]
46-
}
47-
4836
impl FromStr for PrintFormat {
49-
type Err = ();
50-
fn from_str(s: &str) -> std::result::Result<Self, ()> {
51-
match s.to_lowercase().as_str() {
52-
"csv" => Ok(Self::Csv),
53-
"tsv" => Ok(Self::Tsv),
54-
"table" => Ok(Self::Table),
55-
"json" => Ok(Self::Json),
56-
"ndjson" => Ok(Self::NdJson),
57-
_ => Err(()),
58-
}
59-
}
60-
}
37+
type Err = String;
6138

62-
impl fmt::Display for PrintFormat {
63-
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
64-
match *self {
65-
Self::Csv => write!(f, "csv"),
66-
Self::Tsv => write!(f, "tsv"),
67-
Self::Table => write!(f, "table"),
68-
Self::Json => write!(f, "json"),
69-
Self::NdJson => write!(f, "ndjson"),
70-
}
39+
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
40+
clap::ArgEnum::from_str(s, true)
7141
}
7242
}
7343

@@ -123,38 +93,6 @@ mod tests {
12393
use datafusion::from_slice::FromSlice;
12494
use std::sync::Arc;
12595

126-
#[test]
127-
fn test_from_str() {
128-
let format = "csv".parse::<PrintFormat>().unwrap();
129-
assert_eq!(PrintFormat::Csv, format);
130-
131-
let format = "tsv".parse::<PrintFormat>().unwrap();
132-
assert_eq!(PrintFormat::Tsv, format);
133-
134-
let format = "json".parse::<PrintFormat>().unwrap();
135-
assert_eq!(PrintFormat::Json, format);
136-
137-
let format = "ndjson".parse::<PrintFormat>().unwrap();
138-
assert_eq!(PrintFormat::NdJson, format);
139-
140-
let format = "table".parse::<PrintFormat>().unwrap();
141-
assert_eq!(PrintFormat::Table, format);
142-
}
143-
144-
#[test]
145-
fn test_to_str() {
146-
assert_eq!("csv", PrintFormat::Csv.to_string());
147-
assert_eq!("table", PrintFormat::Table.to_string());
148-
assert_eq!("tsv", PrintFormat::Tsv.to_string());
149-
assert_eq!("json", PrintFormat::Json.to_string());
150-
assert_eq!("ndjson", PrintFormat::NdJson.to_string());
151-
}
152-
153-
#[test]
154-
fn test_from_str_failure() {
155-
assert!("pretty".parse::<PrintFormat>().is_err());
156-
}
157-
15896
#[test]
15997
fn test_print_batches_with_sep() {
16098
let batches = vec![];

0 commit comments

Comments
 (0)