
Commit 2741c60

theirix and alamb authored
Use compression type in CSV file suffices (#16609)
* Use compression type in file suffices
  - Add FileFormat::compression_type method
  - Specify meaningful values for CSV only
  - Use compression type as a part of extension for files
* Add CSV tests
* Add glob dep, use env logging
* Use a glob pattern with compression suffix for TableProviderFactory
* Conform to clippy standards

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 01698cb commit 2741c60
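In practice this changes what DataFrame CSV writes produce when compression is configured. A minimal sketch of the new behavior, adapted from the tests added in this PR ("input.csv" and "out/" are placeholder paths, not from the commit):

use datafusion::config::CsvOptions;
use datafusion::dataframe::DataFrameWriteOptions;
use datafusion::prelude::*;
use datafusion_common::parsers::CompressionTypeVariant;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    let df = ctx.read_csv("input.csv", CsvReadOptions::default()).await?;

    // With GZIP configured, files written under "out/" now end in
    // ".csv.gz" rather than ".csv".
    let csv_opts = CsvOptions::default()
        .with_has_header(true)
        .with_compression(CompressionTypeVariant::GZIP);
    df.write_csv("out/", DataFrameWriteOptions::new(), Some(csv_opts))
        .await?;
    Ok(())
}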

File tree

13 files changed: +234 −4 lines changed


Cargo.lock

Lines changed: 1 addition & 0 deletions
Generated file; diff not rendered by default.

datafusion-examples/examples/custom_file_format.rs

Lines changed: 4 additions & 0 deletions
@@ -81,6 +81,10 @@ impl FileFormat for TSVFileFormat {
         }
     }
 
+    fn compression_type(&self) -> Option<FileCompressionType> {
+        None
+    }
+
     async fn infer_schema(
         &self,
         state: &dyn Session,

datafusion-examples/examples/dataframe.rs

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ use tempfile::tempdir;
 /// * [query_to_date]: execute queries against parquet files
 #[tokio::main]
 async fn main() -> Result<()> {
+    env_logger::init();
     // The SessionContext is the main high level API for interacting with DataFusion
     let ctx = SessionContext::new();
     read_parquet(&ctx).await?;

datafusion/core/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -154,6 +154,7 @@ datafusion-macros = { workspace = true }
 datafusion-physical-optimizer = { workspace = true }
 doc-comment = { workspace = true }
 env_logger = { workspace = true }
+glob = { version = "0.3.0" }
 insta = { workspace = true }
 paste = "^1.0"
 rand = { workspace = true, features = ["small_rng"] }

datafusion/core/src/datasource/file_format/arrow.rs

Lines changed: 4 additions & 0 deletions
@@ -134,6 +134,10 @@ impl FileFormat for ArrowFormat {
         }
     }
 
+    fn compression_type(&self) -> Option<FileCompressionType> {
+        None
+    }
+
     async fn infer_schema(
         &self,
         _state: &dyn Session,

datafusion/core/src/datasource/file_format/csv.rs

Lines changed: 81 additions & 0 deletions
@@ -56,6 +56,7 @@ mod tests {
     use async_trait::async_trait;
     use bytes::Bytes;
     use chrono::DateTime;
+    use datafusion_common::parsers::CompressionTypeVariant;
     use futures::stream::BoxStream;
     use futures::StreamExt;
     use insta::assert_snapshot;
@@ -877,6 +878,86 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_csv_extension_compressed() -> Result<()> {
+        // Write compressed CSV files
+        // Expect: under the directory, a file is created with ".csv.gz" extension
+        let ctx = SessionContext::new();
+
+        let df = ctx
+            .read_csv(
+                &format!("{}/csv/aggregate_test_100.csv", arrow_test_data()),
+                CsvReadOptions::default().has_header(true),
+            )
+            .await?;
+
+        let tmp_dir = tempfile::TempDir::new().unwrap();
+        let path = format!("{}", tmp_dir.path().to_string_lossy());
+
+        let cfg1 = crate::dataframe::DataFrameWriteOptions::new();
+        let cfg2 = CsvOptions::default()
+            .with_has_header(true)
+            .with_compression(CompressionTypeVariant::GZIP);
+
+        df.write_csv(&path, cfg1, Some(cfg2)).await?;
+        assert!(std::path::Path::new(&path).exists());
+
+        let files: Vec<_> = std::fs::read_dir(&path).unwrap().collect();
+        assert_eq!(files.len(), 1);
+        assert!(files
+            .last()
+            .unwrap()
+            .as_ref()
+            .unwrap()
+            .path()
+            .file_name()
+            .unwrap()
+            .to_str()
+            .unwrap()
+            .ends_with(".csv.gz"));
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_csv_extension_uncompressed() -> Result<()> {
+        // Write plain uncompressed CSV files
+        // Expect: under the directory, a file is created with ".csv" extension
+        let ctx = SessionContext::new();
+
+        let df = ctx
+            .read_csv(
+                &format!("{}/csv/aggregate_test_100.csv", arrow_test_data()),
+                CsvReadOptions::default().has_header(true),
+            )
+            .await?;
+
+        let tmp_dir = tempfile::TempDir::new().unwrap();
+        let path = format!("{}", tmp_dir.path().to_string_lossy());
+
+        let cfg1 = crate::dataframe::DataFrameWriteOptions::new();
+        let cfg2 = CsvOptions::default().with_has_header(true);
+
+        df.write_csv(&path, cfg1, Some(cfg2)).await?;
+        assert!(std::path::Path::new(&path).exists());
+
+        let files: Vec<_> = std::fs::read_dir(&path).unwrap().collect();
+        assert_eq!(files.len(), 1);
+        assert!(files
+            .last()
+            .unwrap()
+            .as_ref()
+            .unwrap()
+            .path()
+            .file_name()
+            .unwrap()
+            .to_str()
+            .unwrap()
+            .ends_with(".csv"));
+
+        Ok(())
+    }
+
     /// Read multiple empty csv files
     ///
     /// all_empty

datafusion/core/src/datasource/listing_table_factory.rs

Lines changed: 114 additions & 3 deletions
@@ -128,9 +128,21 @@ impl TableProviderFactory for ListingTableFactory {
         // if the folder then rewrite a file path as 'path/*.parquet'
         // to only read the files the reader can understand
         if table_path.is_folder() && table_path.get_glob().is_none() {
-            table_path = table_path.with_glob(
-                format!("*.{}", cmd.file_type.to_lowercase()).as_ref(),
-            )?;
+            // Since there are no files yet to infer an actual extension,
+            // derive the pattern based on compression type.
+            // So for gzipped CSV the pattern is `*.csv.gz`
+            let glob = match options.format.compression_type() {
+                Some(compression) => {
+                    match options.format.get_ext_with_compression(&compression) {
+                        // Use glob based on `FileFormat` extension
+                        Ok(ext) => format!("*.{ext}"),
+                        // Fallback to `file_type`, if not supported by `FileFormat`
+                        Err(_) => format!("*.{}", cmd.file_type.to_lowercase()),
+                    }
+                }
+                None => format!("*.{}", cmd.file_type.to_lowercase()),
+            };
+            table_path = table_path.with_glob(glob.as_ref())?;
         }
         let schema = options.infer_schema(session_state, &table_path).await?;
         let df_schema = Arc::clone(&schema).to_dfschema()?;
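For reference, a short sketch of what the pieces above compose to for gzipped CSV. This snippet is illustrative, not part of the PR, and assumes CsvFormat's defaults plus its with_options builder:

use datafusion::config::CsvOptions;
use datafusion::datasource::file_format::csv::CsvFormat;
use datafusion::datasource::file_format::FileFormat;
use datafusion_common::parsers::CompressionTypeVariant;

fn main() -> datafusion::error::Result<()> {
    let format = CsvFormat::default().with_options(
        CsvOptions::default().with_compression(CompressionTypeVariant::GZIP),
    );
    // compression_type() reports GZIP, and get_ext_with_compression
    // appends ".gz" to the base "csv" extension, so the derived glob
    // matches files such as "part-0.csv.gz".
    let compression = format.compression_type().expect("CSV reports compression");
    let glob = format!("*.{}", format.get_ext_with_compression(&compression)?);
    assert_eq!(glob, "*.csv.gz");
    Ok(())
}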
@@ -175,13 +187,15 @@ fn get_extension(path: &str) -> String {
 
 #[cfg(test)]
 mod tests {
+    use glob::Pattern;
     use std::collections::HashMap;
 
     use super::*;
     use crate::{
         datasource::file_format::csv::CsvFormat, execution::context::SessionContext,
     };
 
+    use datafusion_common::parsers::CompressionTypeVariant;
     use datafusion_common::{Constraints, DFSchema, TableReference};
 
     #[tokio::test]
@@ -264,4 +278,101 @@ mod tests {
         let listing_options = listing_table.options();
         assert_eq!(".tbl", listing_options.file_extension);
     }
+
+    /// Validates that CreateExternalTable with compression
+    /// searches for gzipped files in a directory location
+    #[tokio::test]
+    async fn test_create_using_folder_with_compression() {
+        let dir = tempfile::tempdir().unwrap();
+
+        let factory = ListingTableFactory::new();
+        let context = SessionContext::new();
+        let state = context.state();
+        let name = TableReference::bare("foo");
+
+        let mut options = HashMap::new();
+        options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned());
+        options.insert("format.has_header".into(), "true".into());
+        options.insert("format.compression".into(), "gzip".into());
+        let cmd = CreateExternalTable {
+            name,
+            location: dir.path().to_str().unwrap().to_string(),
+            file_type: "csv".to_string(),
+            schema: Arc::new(DFSchema::empty()),
+            table_partition_cols: vec![],
+            if_not_exists: false,
+            temporary: false,
+            definition: None,
+            order_exprs: vec![],
+            unbounded: false,
+            options,
+            constraints: Constraints::default(),
+            column_defaults: HashMap::new(),
+        };
+        let table_provider = factory.create(&state, &cmd).await.unwrap();
+        let listing_table = table_provider
+            .as_any()
+            .downcast_ref::<ListingTable>()
+            .unwrap();
+
+        // Verify compression is used
+        let format = listing_table.options().format.clone();
+        let csv_format = format.as_any().downcast_ref::<CsvFormat>().unwrap();
+        let csv_options = csv_format.options().clone();
+        assert_eq!(csv_options.compression, CompressionTypeVariant::GZIP);
+
+        let listing_options = listing_table.options();
+        assert_eq!("", listing_options.file_extension);
+        // Glob pattern is set to search for gzipped files
+        let table_path = listing_table.table_paths().first().unwrap();
+        assert_eq!(
+            table_path.get_glob().clone().unwrap(),
+            Pattern::new("*.csv.gz").unwrap()
+        );
+    }
+
+    /// Validates that CreateExternalTable without compression
+    /// searches for normal files in a directory location
+    #[tokio::test]
+    async fn test_create_using_folder_without_compression() {
+        let dir = tempfile::tempdir().unwrap();
+
+        let factory = ListingTableFactory::new();
+        let context = SessionContext::new();
+        let state = context.state();
+        let name = TableReference::bare("foo");
+
+        let mut options = HashMap::new();
+        options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned());
+        options.insert("format.has_header".into(), "true".into());
+        let cmd = CreateExternalTable {
+            name,
+            location: dir.path().to_str().unwrap().to_string(),
+            file_type: "csv".to_string(),
+            schema: Arc::new(DFSchema::empty()),
+            table_partition_cols: vec![],
+            if_not_exists: false,
+            temporary: false,
+            definition: None,
+            order_exprs: vec![],
+            unbounded: false,
+            options,
+            constraints: Constraints::default(),
+            column_defaults: HashMap::new(),
+        };
+        let table_provider = factory.create(&state, &cmd).await.unwrap();
+        let listing_table = table_provider
+            .as_any()
+            .downcast_ref::<ListingTable>()
+            .unwrap();
+
+        let listing_options = listing_table.options();
+        assert_eq!("", listing_options.file_extension);
+        // Glob pattern is set to search for plain, uncompressed CSV files
+        let table_path = listing_table.table_paths().first().unwrap();
+        assert_eq!(
+            table_path.get_glob().clone().unwrap(),
+            Pattern::new("*.csv").unwrap()
+        );
+    }
 }

datafusion/core/src/physical_planner.rs

Lines changed: 9 additions & 1 deletion
@@ -533,6 +533,14 @@ impl DefaultPhysicalPlanner {
         let sink_format = file_type_to_format(file_type)?
             .create(session_state, source_option_tuples)?;
 
+        // Determine extension based on format extension and compression
+        let file_extension = match sink_format.compression_type() {
+            Some(compression_type) => sink_format
+                .get_ext_with_compression(&compression_type)
+                .unwrap_or_else(|_| sink_format.get_ext()),
+            None => sink_format.get_ext(),
+        };
+
         // Set file sink related options
         let config = FileSinkConfig {
             original_url,
@@ -543,7 +551,7 @@ impl DefaultPhysicalPlanner {
             table_partition_cols,
             insert_op: InsertOp::Append,
             keep_partition_by_columns,
-            file_extension: sink_format.get_ext(),
+            file_extension,
         };
 
         sink_format
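For a gzipped CSV sink this resolves to "csv.gz": CsvFormat's base extension is "csv" and the GZIP suffix already carries the dot, as the get_ext_with_compression override in the datafusion-datasource-csv hunk below shows. A one-line sanity check, assuming the current re-export path for FileCompressionType:

use datafusion::datasource::file_format::file_compression_type::FileCompressionType;

// GZIP's suffix includes the dot, which is why get_ext_with_compression
// can simply concatenate "csv" + ".gz" into "csv.gz".
assert_eq!(FileCompressionType::GZIP.get_ext(), ".gz");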

datafusion/datasource-avro/src/file_format.rs

Lines changed: 4 additions & 0 deletions
@@ -110,6 +110,10 @@ impl FileFormat for AvroFormat {
         }
     }
 
+    fn compression_type(&self) -> Option<FileCompressionType> {
+        None
+    }
+
     async fn infer_schema(
         &self,
         _state: &dyn Session,

datafusion/datasource-csv/src/file_format.rs

Lines changed: 4 additions & 0 deletions
@@ -358,6 +358,10 @@ impl FileFormat for CsvFormat {
         Ok(format!("{}{}", ext, file_compression_type.get_ext()))
     }
 
+    fn compression_type(&self) -> Option<FileCompressionType> {
+        Some(self.options.compression.into())
+    }
+
     async fn infer_schema(
         &self,
         state: &dyn Session,
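The `.into()` above works because FileCompressionType can be built from the CompressionTypeVariant stored in CsvOptions. A sketch of that conversion, with import paths assumed per the current crate layout:

use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion_common::parsers::CompressionTypeVariant;

// CsvOptions stores a CompressionTypeVariant; the trait method lifts it
// into a FileCompressionType via the From impl.
let fct: FileCompressionType = CompressionTypeVariant::GZIP.into();
assert!(fct.is_compressed());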
