Skip to content

Commit

Permalink
paralellize parquet (#7483)
Browse files Browse the repository at this point in the history
  • Loading branch information
devinjdangelo authored Sep 7, 2023
1 parent 63e452a commit 3a52ee1
Showing 1 changed file with 44 additions and 14 deletions.
58 changes: 44 additions & 14 deletions datafusion/core/src/datasource/file_format/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use std::any::Any;
use std::fmt;
use std::fmt::Debug;
use std::sync::Arc;
use tokio::task::JoinSet;

use arrow::datatypes::SchemaRef;
use arrow::datatypes::{Fields, Schema};
Expand Down Expand Up @@ -719,22 +720,51 @@ impl DataSink for ParquetSink {
}

let mut row_count = 0;
// TODO parallelize serialization accross partitions and batches within partitions
// see: https://github.com/apache/arrow-datafusion/issues/7079
for (part_idx, data_stream) in data.iter_mut().enumerate().take(num_partitions) {
let idx = match self.config.single_file_output {
false => part_idx,
true => 0,
};
while let Some(batch) = data_stream.next().await.transpose()? {
row_count += batch.num_rows();
// TODO cleanup all multipart writes when any encounters an error
writers[idx].write(&batch).await?;

match self.config.single_file_output {
false => {
let mut join_set: JoinSet<Result<usize, DataFusionError>> =
JoinSet::new();
for (mut data_stream, mut writer) in
data.into_iter().zip(writers.into_iter())
{
join_set.spawn(async move {
let mut cnt = 0;
while let Some(batch) = data_stream.next().await.transpose()? {
cnt += batch.num_rows();
writer.write(&batch).await?;
}
writer.close().await?;
Ok(cnt)
});
}
while let Some(result) = join_set.join_next().await {
match result {
Ok(res) => {
row_count += res?;
} // propagate DataFusion error
Err(e) => {
if e.is_panic() {
std::panic::resume_unwind(e.into_panic());
} else {
unreachable!();
}
}
}
}
}
}
true => {
let mut writer = writers.remove(0);
for data_stream in data.iter_mut() {
while let Some(batch) = data_stream.next().await.transpose()? {
row_count += batch.num_rows();
// TODO cleanup all multipart writes when any encounters an error
writer.write(&batch).await?;
}
}

for writer in writers {
writer.close().await?;
writer.close().await?;
}
}

Ok(row_count as u64)
Expand Down

0 comments on commit 3a52ee1

Please sign in to comment.