@@ -18,18 +18,22 @@
 //! Execution plan for reading CSV files

 use crate::error::{DataFusionError, Result};
+use crate::execution::context::ExecutionContext;
 use crate::physical_plan::expressions::PhysicalSortExpr;
 use crate::physical_plan::{
     DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics,
 };

+use crate::execution::runtime_env::RuntimeEnv;
 use arrow::csv;
 use arrow::datatypes::SchemaRef;
+use async_trait::async_trait;
+use futures::{StreamExt, TryStreamExt};
 use std::any::Any;
+use std::fs;
+use std::path::Path;
 use std::sync::Arc;
-
-use crate::execution::runtime_env::RuntimeEnv;
-use async_trait::async_trait;
+use tokio::task::{self, JoinHandle};

 use super::file_stream::{BatchIter, FileStream};
 use super::FileScanConfig;
@@ -176,16 +180,73 @@ impl ExecutionPlan for CsvExec {
     }
 }

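+/// Execute `plan` and write its results to CSV files under the directory at
+/// `path`, one `part-{i}.csv` file per output partition; the directory must
+/// not already exist. A minimal usage sketch (assuming an `ExecutionContext`
+/// `ctx` and a physical plan `plan` built elsewhere; names are illustrative):
+///
+/// ```ignore
+/// plan_to_csv(&ctx, plan, "/tmp/query_out").await?;
+/// ```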
+pub async fn plan_to_csv(
+    context: &ExecutionContext,
+    plan: Arc<dyn ExecutionPlan>,
+    path: impl AsRef<str>,
+) -> Result<()> {
+    let path = path.as_ref();
+    // create directory to contain the CSV files (one per partition)
+    let fs_path = Path::new(path);
+    let runtime = context.runtime_env();
+    match fs::create_dir(fs_path) {
+        Ok(()) => {
+            let mut tasks = vec![];
+            for i in 0..plan.output_partitioning().partition_count() {
+                let plan = plan.clone();
+                let filename = format!("part-{}.csv", i);
+                let path = fs_path.join(&filename);
+                let file = fs::File::create(path)?;
+                let mut writer = csv::Writer::new(file);
+                let stream = plan.execute(i, runtime.clone()).await?;
+                // stream the partition's record batches through the CSV
+                // writer on a dedicated task so partitions write concurrently
+                let handle: JoinHandle<Result<()>> = task::spawn(async move {
+                    stream
+                        .map(|batch| writer.write(&batch?))
+                        .try_collect()
+                        .await
+                        .map_err(DataFusionError::from)
+                });
+                tasks.push(handle);
+            }
+            // wait for all writers to finish, surfacing any task or write error
+            for task in tasks {
+                task.await
+                    .map_err(|e| DataFusionError::Execution(format!("{}", e)))??;
+            }
+            Ok(())
+        }
+        Err(e) => Err(DataFusionError::Execution(format!(
+            "Could not create directory {}: {:?}",
+            path, e
+        ))),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::prelude::*;
     use crate::test_util::aggr_test_schema_with_missing_col;
     use crate::{
         datasource::object_store::local::{local_unpartitioned_file, LocalFileSystem},
         scalar::ScalarValue,
         test_util::aggr_test_schema,
     };
+    use arrow::datatypes::*;
     use futures::StreamExt;
+    use std::fs::File;
+    use std::io::Write;
+    use tempfile::TempDir;

     #[tokio::test]
     async fn csv_exec_with_projection() -> Result<()> {
@@ -376,4 +437,87 @@ mod tests {
         crate::assert_batches_eq!(expected, &[batch.slice(0, 5)]);
         Ok(())
     }
+
+    /// Generate CSV partitions within the supplied directory
+    fn populate_csv_partitions(
+        tmp_dir: &TempDir,
+        partition_count: usize,
+        file_extension: &str,
+    ) -> Result<SchemaRef> {
+        // define schema for data source (csv file)
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("c1", DataType::UInt32, false),
+            Field::new("c2", DataType::UInt64, false),
+            Field::new("c3", DataType::Boolean, false),
+        ]));
+
+        // generate one csv file per partition
+        for partition in 0..partition_count {
+            let filename = format!("partition-{}.{}", partition, file_extension);
+            let file_path = tmp_dir.path().join(&filename);
+            let mut file = File::create(file_path)?;
+
+            // generate some data
+            for i in 0..=10 {
+                let data = format!("{},{},{}\n", partition, i, i % 2 == 0);
+                file.write_all(data.as_bytes())?;
+            }
+        }
+
+        Ok(schema)
+    }
+
+    #[tokio::test]
+    async fn write_csv_results() -> Result<()> {
+        // create partitioned input files and context
+        let tmp_dir = TempDir::new()?;
+        let mut ctx = ExecutionContext::with_config(
+            ExecutionConfig::new().with_target_partitions(8),
+        );
+
+        let schema = populate_csv_partitions(&tmp_dir, 8, "csv")?;
+
+        // register csv file with the execution context
+        ctx.register_csv(
+            "test",
+            tmp_dir.path().to_str().unwrap(),
+            CsvReadOptions::new().schema(&schema),
+        )
+        .await?;
+
+        // execute a simple query and write the results to CSV
+        let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out";
+        let df = ctx.sql("SELECT c1, c2 FROM test").await?;
+        df.write_csv(&out_dir).await?;
+
+        // create a new context and verify that the results were saved to a partitioned csv file
+        let mut ctx = ExecutionContext::new();
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("c1", DataType::UInt32, false),
+            Field::new("c2", DataType::UInt64, false),
+        ]));
+
+        // register each partition as well as the top level dir
+        let csv_read_option = CsvReadOptions::new().schema(&schema);
+        ctx.register_csv("part0", &format!("{}/part-0.csv", out_dir), csv_read_option)
+            .await?;
+        ctx.register_csv("allparts", &out_dir, csv_read_option)
+            .await?;
+
+        let part0 = ctx.sql("SELECT c1, c2 FROM part0").await?.collect().await?;
+        let allparts = ctx
+            .sql("SELECT c1, c2 FROM allparts")
+            .await?
+            .collect()
+            .await?;
+
+        let allparts_count: usize = allparts.iter().map(|batch| batch.num_rows()).sum();
+
+        assert_eq!(part0[0].schema(), allparts[0].schema());
+
+        assert_eq!(allparts_count, 80);
+
+        Ok(())
+    }
 }
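For reference, the end-to-end flow this change enables looks roughly like the following sketch (assuming the `register_csv` setup from the test above; the table name and output path are illustrative, and `DataFrame::write_csv` is expected to funnel into `plan_to_csv`):

    let df = ctx.sql("SELECT c1, c2 FROM test").await?;
    // writes one part-{i}.csv file per output partition under the directory
    df.write_csv("/tmp/query_out").await?;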