1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 4 additions & 0 deletions datafusion-examples/examples/custom_file_format.rs
@@ -81,6 +81,10 @@ impl FileFormat for TSVFileFormat {
}
}

fn compression_type(&self) -> Option<FileCompressionType> {
None
}

async fn infer_schema(
&self,
state: &dyn Session,
1 change: 1 addition & 0 deletions datafusion-examples/examples/dataframe.rs
@@ -59,6 +59,7 @@ use tempfile::tempdir;
/// * [query_to_date]: execute queries against parquet files
#[tokio::main]
async fn main() -> Result<()> {
env_logger::init();
// The SessionContext is the main high level API for interacting with DataFusion
let ctx = SessionContext::new();
read_parquet(&ctx).await?;
1 change: 1 addition & 0 deletions datafusion/core/Cargo.toml
@@ -154,6 +154,7 @@ datafusion-macros = { workspace = true }
datafusion-physical-optimizer = { workspace = true }
doc-comment = { workspace = true }
env_logger = { workspace = true }
glob = { version = "0.3.0" }
insta = { workspace = true }
paste = "^1.0"
rand = { workspace = true, features = ["small_rng"] }
4 changes: 4 additions & 0 deletions datafusion/core/src/datasource/file_format/arrow.rs
@@ -134,6 +134,10 @@ impl FileFormat for ArrowFormat {
}
}

fn compression_type(&self) -> Option<FileCompressionType> {
None
}

async fn infer_schema(
&self,
_state: &dyn Session,
81 changes: 81 additions & 0 deletions datafusion/core/src/datasource/file_format/csv.rs
@@ -56,6 +56,7 @@ mod tests {
use async_trait::async_trait;
use bytes::Bytes;
use chrono::DateTime;
use datafusion_common::parsers::CompressionTypeVariant;
use futures::stream::BoxStream;
use futures::StreamExt;
use insta::assert_snapshot;
@@ -877,6 +878,86 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_csv_extension_compressed() -> Result<()> {
// Write compressed CSV files
// Expect: under the directory, a file is created with ".csv.gz" extension
let ctx = SessionContext::new();

let df = ctx
.read_csv(
&format!("{}/csv/aggregate_test_100.csv", arrow_test_data()),
CsvReadOptions::default().has_header(true),
)
.await?;

let tmp_dir = tempfile::TempDir::new().unwrap();
let path = format!("{}", tmp_dir.path().to_string_lossy());

let cfg1 = crate::dataframe::DataFrameWriteOptions::new();
let cfg2 = CsvOptions::default()
.with_has_header(true)
.with_compression(CompressionTypeVariant::GZIP);

df.write_csv(&path, cfg1, Some(cfg2)).await?;
assert!(std::path::Path::new(&path).exists());

let files: Vec<_> = std::fs::read_dir(&path).unwrap().collect();
assert_eq!(files.len(), 1);
assert!(files
.last()
.unwrap()
.as_ref()
.unwrap()
.path()
.file_name()
.unwrap()
.to_str()
.unwrap()
.ends_with(".csv.gz"));

Ok(())
}

#[tokio::test]
async fn test_csv_extension_uncompressed() -> Result<()> {
// Write plain uncompressed CSV files
// Expect: under the directory, a file is created with ".csv" extension
let ctx = SessionContext::new();

let df = ctx
.read_csv(
&format!("{}/csv/aggregate_test_100.csv", arrow_test_data()),
CsvReadOptions::default().has_header(true),
)
.await?;

let tmp_dir = tempfile::TempDir::new().unwrap();
let path = format!("{}", tmp_dir.path().to_string_lossy());

let cfg1 = crate::dataframe::DataFrameWriteOptions::new();
let cfg2 = CsvOptions::default().with_has_header(true);

df.write_csv(&path, cfg1, Some(cfg2)).await?;
assert!(std::path::Path::new(&path).exists());

let files: Vec<_> = std::fs::read_dir(&path).unwrap().collect();
assert_eq!(files.len(), 1);
assert!(files
.last()
.unwrap()
.as_ref()
.unwrap()
.path()
.file_name()
.unwrap()
.to_str()
.unwrap()
.ends_with(".csv"));

Ok(())
}

/// Read multiple empty csv files
///
/// all_empty
117 changes: 114 additions & 3 deletions datafusion/core/src/datasource/listing_table_factory.rs
@@ -128,9 +128,21 @@ impl TableProviderFactory for ListingTableFactory {
// if the location is a folder, rewrite the file path as 'path/*.parquet'
// to only read the files the reader can understand
if table_path.is_folder() && table_path.get_glob().is_none() {
table_path = table_path.with_glob(
format!("*.{}", cmd.file_type.to_lowercase()).as_ref(),
)?;
// Since there are no files yet to infer an actual extension,
// derive the pattern based on compression type.
// So for gzipped CSV the pattern is `*.csv.gz`
let glob = match options.format.compression_type() {
Some(compression) => {
match options.format.get_ext_with_compression(&compression) {
// Use glob based on `FileFormat` extension
Ok(ext) => format!("*.{ext}"),
// Fallback to `file_type`, if not supported by `FileFormat`
Err(_) => format!("*.{}", cmd.file_type.to_lowercase()),
}
}
None => format!("*.{}", cmd.file_type.to_lowercase()),
};
table_path = table_path.with_glob(glob.as_ref())?;
}
let schema = options.infer_schema(session_state, &table_path).await?;
let df_schema = Arc::clone(&schema).to_dfschema()?;
@@ -175,13 +187,15 @@ fn get_extension(path: &str) -> String {

#[cfg(test)]
mod tests {
use glob::Pattern;
use std::collections::HashMap;

use super::*;
use crate::{
datasource::file_format::csv::CsvFormat, execution::context::SessionContext,
};

use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::{Constraints, DFSchema, TableReference};

#[tokio::test]
@@ -264,4 +278,101 @@ mod tests {
let listing_options = listing_table.options();
assert_eq!(".tbl", listing_options.file_extension);
}

/// Validates that CreateExternalTable with compression
/// searches for gzipped files in a directory location
#[tokio::test]
async fn test_create_using_folder_with_compression() {
let dir = tempfile::tempdir().unwrap();

let factory = ListingTableFactory::new();
let context = SessionContext::new();
let state = context.state();
let name = TableReference::bare("foo");

let mut options = HashMap::new();
options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned());
options.insert("format.has_header".into(), "true".into());
options.insert("format.compression".into(), "gzip".into());
let cmd = CreateExternalTable {
name,
location: dir.path().to_str().unwrap().to_string(),
file_type: "csv".to_string(),
schema: Arc::new(DFSchema::empty()),
table_partition_cols: vec![],
if_not_exists: false,
temporary: false,
definition: None,
order_exprs: vec![],
unbounded: false,
options,
constraints: Constraints::default(),
column_defaults: HashMap::new(),
};
let table_provider = factory.create(&state, &cmd).await.unwrap();
let listing_table = table_provider
.as_any()
.downcast_ref::<ListingTable>()
.unwrap();

// Verify compression is used
let format = listing_table.options().format.clone();
let csv_format = format.as_any().downcast_ref::<CsvFormat>().unwrap();
let csv_options = csv_format.options().clone();
assert_eq!(csv_options.compression, CompressionTypeVariant::GZIP);

let listing_options = listing_table.options();
assert_eq!("", listing_options.file_extension);
// Glob pattern is set to search for gzipped files
let table_path = listing_table.table_paths().first().unwrap();
assert_eq!(
table_path.get_glob().clone().unwrap(),
Pattern::new("*.csv.gz").unwrap()
);
}

/// Validates that CreateExternalTable without compression
/// searches for normal files in a directory location
#[tokio::test]
async fn test_create_using_folder_without_compression() {
let dir = tempfile::tempdir().unwrap();

let factory = ListingTableFactory::new();
let context = SessionContext::new();
let state = context.state();
let name = TableReference::bare("foo");

let mut options = HashMap::new();
options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned());
options.insert("format.has_header".into(), "true".into());
let cmd = CreateExternalTable {
name,
location: dir.path().to_str().unwrap().to_string(),
file_type: "csv".to_string(),
schema: Arc::new(DFSchema::empty()),
table_partition_cols: vec![],
if_not_exists: false,
temporary: false,
definition: None,
order_exprs: vec![],
unbounded: false,
options,
constraints: Constraints::default(),
column_defaults: HashMap::new(),
};
let table_provider = factory.create(&state, &cmd).await.unwrap();
let listing_table = table_provider
.as_any()
.downcast_ref::<ListingTable>()
.unwrap();

let listing_options = listing_table.options();
assert_eq!("", listing_options.file_extension);
// Glob pattern is set to search for plain, uncompressed csv files
let table_path = listing_table.table_paths().first().unwrap();
assert_eq!(
table_path.get_glob().clone().unwrap(),
Pattern::new("*.csv").unwrap()
);
}
}
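
For readability, the glob derivation introduced in this file can be read as the standalone sketch below. It is not part of the diff; the helper name `derive_glob` is illustrative, and only `FileFormat` methods that appear in this pull request are relied on.

use datafusion::datasource::file_format::FileFormat;

// Mirror of the factory logic above: prefer the format's compression-aware
// extension, otherwise fall back to the user-supplied file type.
fn derive_glob(format: &dyn FileFormat, file_type: &str) -> String {
    match format.compression_type() {
        Some(compression) => match format.get_ext_with_compression(&compression) {
            // e.g. gzipped CSV yields "csv.gz", so the listing glob is "*.csv.gz"
            Ok(ext) => format!("*.{ext}"),
            // The format does not support the requested compression
            Err(_) => format!("*.{}", file_type.to_lowercase()),
        },
        // Formats without file-level compression keep "*.csv", "*.parquet", ...
        None => format!("*.{}", file_type.to_lowercase()),
    }
}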
10 changes: 9 additions & 1 deletion datafusion/core/src/physical_planner.rs
@@ -533,6 +533,14 @@ impl DefaultPhysicalPlanner {
let sink_format = file_type_to_format(file_type)?
.create(session_state, source_option_tuples)?;

// Determine extension based on format extension and compression
let file_extension = match sink_format.compression_type() {
Some(compression_type) => sink_format
.get_ext_with_compression(&compression_type)
.unwrap_or_else(|_| sink_format.get_ext()),
None => sink_format.get_ext(),
};

// Set file sink related options
let config = FileSinkConfig {
original_url,
@@ -543,7 +551,7 @@
table_partition_cols,
insert_op: InsertOp::Append,
keep_partition_by_columns,
file_extension: sink_format.get_ext(),
file_extension,
};

sink_format
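
A hedged usage sketch of the planner change (not part of the diff): for a CSV sink configured with gzip, the derived extension becomes "csv.gz" instead of the bare "csv". `CsvFormat::with_options` and the import paths are assumed from the `datafusion` facade crate rather than taken from this pull request.

use datafusion::datasource::file_format::csv::CsvFormat;
use datafusion::datasource::file_format::FileFormat;
use datafusion_common::config::CsvOptions;
use datafusion_common::parsers::CompressionTypeVariant;

fn main() {
    let format = CsvFormat::default()
        .with_options(CsvOptions::default().with_compression(CompressionTypeVariant::GZIP));

    // Same selection the planner now performs when building FileSinkConfig
    let file_extension = match format.compression_type() {
        Some(compression_type) => format
            .get_ext_with_compression(&compression_type)
            .unwrap_or_else(|_| format.get_ext()),
        None => format.get_ext(),
    };

    // Written files are therefore named like `<...>.csv.gz` rather than `<...>.csv`
    assert_eq!(file_extension, "csv.gz");
}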
4 changes: 4 additions & 0 deletions datafusion/datasource-avro/src/file_format.rs
@@ -110,6 +110,10 @@ impl FileFormat for AvroFormat {
}
}

fn compression_type(&self) -> Option<FileCompressionType> {
None
}

async fn infer_schema(
&self,
_state: &dyn Session,
4 changes: 4 additions & 0 deletions datafusion/datasource-csv/src/file_format.rs
@@ -358,6 +358,10 @@ impl FileFormat for CsvFormat {
Ok(format!("{}{}", ext, file_compression_type.get_ext()))
}

fn compression_type(&self) -> Option<FileCompressionType> {
Some(self.options.compression.into())
}

async fn infer_schema(
&self,
state: &dyn Session,
4 changes: 4 additions & 0 deletions datafusion/datasource-json/src/file_format.rs
@@ -185,6 +185,10 @@ impl FileFormat for JsonFormat {
Ok(format!("{}{}", ext, file_compression_type.get_ext()))
}

fn compression_type(&self) -> Option<FileCompressionType> {
Some(self.options.compression.into())
}

async fn infer_schema(
&self,
_state: &dyn Session,
4 changes: 4 additions & 0 deletions datafusion/datasource-parquet/src/file_format.rs
@@ -340,6 +340,10 @@ impl FileFormat for ParquetFormat {
}
}

fn compression_type(&self) -> Option<FileCompressionType> {
None
}

async fn infer_schema(
&self,
state: &dyn Session,
3 changes: 3 additions & 0 deletions datafusion/datasource/src/file_format.rs
@@ -61,6 +61,9 @@ pub trait FileFormat: Send + Sync + fmt::Debug {
_file_compression_type: &FileCompressionType,
) -> Result<String>;

/// Returns the file-level compression type used by this format, if any
fn compression_type(&self) -> Option<FileCompressionType>;

/// Infer the common schema of the provided objects. The objects will usually
/// be analysed up to a given number of records or files (as specified in the
/// format config) then give the estimated common schema. This might fail if
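
Since `compression_type` is added as a required method, every `FileFormat` implementation, including custom ones like the `TSVFileFormat` example above, has to provide it. Below is a minimal, hypothetical sketch of the implementor's side, assuming a custom format that applies file-level gzip when writing; `MyLinesFormat` is illustrative and only the method signature comes from this diff.

use datafusion::datasource::file_format::file_compression_type::FileCompressionType;

#[derive(Debug)]
struct MyLinesFormat {
    // File-level compression applied when writing, if any
    compression: Option<FileCompressionType>,
}

impl MyLinesFormat {
    // In a real `impl FileFormat for MyLinesFormat`, this same body satisfies
    // the new required method: report the configured file-level compression,
    // or None for formats that compress internally (Parquet, Arrow, Avro).
    fn compression_type(&self) -> Option<FileCompressionType> {
        self.compression.clone()
    }
}

fn main() {
    let fmt = MyLinesFormat { compression: Some(FileCompressionType::GZIP) };
    assert!(fmt.compression_type().is_some());
}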