Skip to content

Commit

Permalink
Provide access to inner Write for parquet writers
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Mar 4, 2024
1 parent a02ceba commit 96de92b
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 68 deletions.
23 changes: 21 additions & 2 deletions parquet/src/arrow/arrow_writer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,35 @@ impl<W: Write + Send> ArrowWriter<W> {
self.writer.append_key_value_metadata(kv_metadata)
}

/// Returns a reference to the underlying writer.
///
/// Useful for inspecting the destination (e.g. a `Vec<u8>` buffer)
/// without disturbing the writer's state.
pub fn inner(&self) -> &W {
self.writer.inner()
}

/// Returns a mutable reference to the underlying writer.
///
/// It is inadvisable to directly write to the underlying writer.
///
/// NOTE(review): bytes written here bypass the parquet framing entirely;
/// presumably they corrupt the file unless the caller understands the
/// on-disk layout — confirm intended use before writing through this.
pub fn inner_mut(&mut self) -> &mut W {
self.writer.inner_mut()
}

/// Flushes any outstanding data and returns the underlying writer.
pub fn into_inner(mut self) -> Result<W> {
// Write any buffered rows out as a final row group before
// surrendering ownership of the destination writer.
match self.flush() {
Ok(()) => self.writer.into_inner(),
Err(e) => Err(e),
}
}

/// Close and finalize the underlying Parquet writer
pub fn close(mut self) -> Result<crate::format::FileMetaData> {
///
/// Unlike [`Self::close`] this does not consume self
pub fn finish(&mut self) -> Result<crate::format::FileMetaData> {
self.flush()?;
self.writer.close()
self.writer.finish()
}

/// Close and finalize the underlying Parquet writer
pub fn close(mut self) -> Result<crate::format::FileMetaData> {
self.finish()
}
}

Expand Down
81 changes: 18 additions & 63 deletions parquet/src/arrow/async_writer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@
//! # }
//! ```

use std::{io::Write, sync::Arc};

use crate::{
arrow::arrow_writer::ArrowWriterOptions,
arrow::ArrowWriter,
Expand All @@ -71,22 +69,19 @@ use tokio::io::{AsyncWrite, AsyncWriteExt};
/// buffer's threshold is exceeded.
pub struct AsyncArrowWriter<W> {
/// Underlying sync writer
sync_writer: ArrowWriter<SharedBuffer>,
sync_writer: ArrowWriter<Vec<u8>>,

/// Async writer provided by caller
async_writer: W,

/// The inner buffer shared by the `sync_writer` and the `async_writer`
shared_buffer: SharedBuffer,

/// Trigger forced flushing once buffer size reaches this value
buffer_size: usize,
}

impl<W: AsyncWrite + Unpin + Send> AsyncArrowWriter<W> {
/// Try to create a new Async Arrow Writer.
///
/// `buffer_size` determines the number of bytes to buffer before flushing
/// `buffer_size` determines the minimum number of bytes to buffer before flushing
/// to the underlying [`AsyncWrite`]
///
/// The intermediate buffer will automatically be resized if necessary
Expand All @@ -105,7 +100,7 @@ impl<W: AsyncWrite + Unpin + Send> AsyncArrowWriter<W> {

/// Try to create a new Async Arrow Writer with [`ArrowWriterOptions`].
///
/// `buffer_size` determines the number of bytes to buffer before flushing
/// `buffer_size` determines the minimum number of bytes to buffer before flushing
/// to the underlying [`AsyncWrite`]
///
/// The intermediate buffer will automatically be resized if necessary
Expand All @@ -118,14 +113,15 @@ impl<W: AsyncWrite + Unpin + Send> AsyncArrowWriter<W> {
buffer_size: usize,
options: ArrowWriterOptions,
) -> Result<Self> {
let shared_buffer = SharedBuffer::new(buffer_size);
let sync_writer =
ArrowWriter::try_new_with_options(shared_buffer.clone(), arrow_schema, options)?;
let sync_writer = ArrowWriter::try_new_with_options(
Vec::with_capacity(buffer_size),
arrow_schema,
options,
)?;

Ok(Self {
sync_writer,
async_writer: writer,
shared_buffer,
buffer_size,
})
}
Expand Down Expand Up @@ -156,18 +152,13 @@ impl<W: AsyncWrite + Unpin + Send> AsyncArrowWriter<W> {
/// checked and flush if at least half full
pub async fn write(&mut self, batch: &RecordBatch) -> Result<()> {
// Buffer the batch via the synchronous writer, then opportunistically
// push bytes to the async writer once the buffer threshold is reached.
self.sync_writer.write(batch)?;
self.try_flush(false).await
}

/// Flushes all buffered rows into a new row group
pub async fn flush(&mut self) -> Result<()> {
self.sync_writer.flush()?;
// Non-forced: data is only shipped to the async writer if the internal
// buffer has reached its configured threshold.
self.try_flush(false).await?;

Ok(())
}
Expand All @@ -183,34 +174,29 @@ impl<W: AsyncWrite + Unpin + Send> AsyncArrowWriter<W> {
///
/// All the data in the inner buffer will be force flushed.
pub async fn close(mut self) -> Result<FileMetaData> {
// Finalize the sync writer first so the parquet footer lands in the buffer.
let metadata = self.sync_writer.finish()?;

// Force to flush the remaining data.
self.try_flush(true).await?;
self.async_writer.shutdown().await?;

Ok(metadata)
}

/// Flush the data in the [`SharedBuffer`] into the `async_writer` if its size
/// exceeds the threshold.
async fn try_flush(
shared_buffer: &mut SharedBuffer,
async_writer: &mut W,
buffer_size: usize,
) -> Result<()> {
let mut buffer = shared_buffer.buffer.try_lock().unwrap();
if buffer.is_empty() || buffer.len() < buffer_size {
/// Flush the buffered data into the `async_writer`
async fn try_flush(&mut self, force: bool) -> Result<()> {
let buffer = self.sync_writer.inner_mut();
if !force && (buffer.is_empty() || buffer.len() < self.buffer_size) {
// no need to flush
return Ok(());
}

async_writer
self.async_writer
.write_all(buffer.as_slice())
.await
.map_err(|e| ParquetError::External(Box::new(e)))?;

async_writer
self.async_writer
.flush()
.await
.map_err(|e| ParquetError::External(Box::new(e)))?;
Expand All @@ -222,37 +208,6 @@ impl<W: AsyncWrite + Unpin + Send> AsyncArrowWriter<W> {
}
}

/// A buffer with interior mutability shared by the [`ArrowWriter`] and
/// [`AsyncArrowWriter`].
#[derive(Clone)]
struct SharedBuffer {
/// The inner buffer for reading and writing
///
/// The lock is used to obtain internal mutability, so no worry about the
/// lock contention.
///
/// NOTE(review): all accessors call `try_lock().unwrap()`, which panics
/// if the lock is ever contended — this is sound only while every access
/// stays on a single thread; confirm before sharing across tasks.
buffer: Arc<futures::lock::Mutex<Vec<u8>>>,
}

impl SharedBuffer {
/// Creates a shared buffer pre-sized to `capacity` bytes.
pub fn new(capacity: usize) -> Self {
// Pre-allocate once so early writes avoid repeated grow-and-copy.
let inner = Vec::with_capacity(capacity);
Self {
buffer: Arc::new(futures::lock::Mutex::new(inner)),
}
}
}

impl Write for SharedBuffer {
/// Appends `buf` to the shared byte buffer.
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
// `try_lock().unwrap()` panics on contention; access is assumed
// single-threaded here (same assumption as the original).
self.buffer.try_lock().unwrap().write(buf)
}

/// Forwards `flush` to the inner `Vec<u8>` (a no-op for in-memory buffers).
fn flush(&mut self) -> std::io::Result<()> {
self.buffer.try_lock().unwrap().flush()
}
}

#[cfg(test)]
mod tests {
use arrow::datatypes::{DataType, Field, Schema};
Expand Down
47 changes: 44 additions & 3 deletions parquet/src/file/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ impl<W: Write> TrackedWrite<W> {
self.bytes_written
}

/// Returns a reference to the underlying writer.
pub fn inner(&self) -> &W {
self.inner.get_ref()
}

/// Returns a mutable reference to the underlying writer.
///
/// It is inadvisable to directly write to the underlying writer.
///
/// NOTE(review): writes made directly through this reference presumably
/// bypass the `bytes_written` accounting of this tracker — confirm before
/// relying on recorded offsets afterwards.
pub fn inner_mut(&mut self) -> &mut W {
self.inner.get_mut()
}

/// Returns the underlying writer.
pub fn into_inner(self) -> Result<W> {
self.inner.into_inner().map_err(|err| {
Expand Down Expand Up @@ -137,6 +149,7 @@ pub struct SerializedFileWriter<W: Write> {
row_group_index: usize,
// kv_metadatas will be appended to `props` when `write_metadata`
kv_metadatas: Vec<KeyValue>,
finished: bool,
}

impl<W: Write> Debug for SerializedFileWriter<W> {
Expand Down Expand Up @@ -167,6 +180,7 @@ impl<W: Write + Send> SerializedFileWriter<W> {
offset_indexes: Vec::new(),
row_group_index: 0,
kv_metadatas: Vec::new(),
finished: false,
})
}

Expand Down Expand Up @@ -210,13 +224,18 @@ impl<W: Write + Send> SerializedFileWriter<W> {
&self.row_groups
}

/// Closes and finalises file writer, returning the file metadata.
pub fn close(mut self) -> Result<parquet::FileMetaData> {
pub fn finish(&mut self) -> Result<parquet::FileMetaData> {
self.assert_previous_writer_closed()?;
let metadata = self.write_metadata()?;
self.buf.flush()?;
Ok(metadata)
}

/// Closes and finalises file writer, returning the file metadata.
pub fn close(mut self) -> Result<parquet::FileMetaData> {
self.finish()
}

/// Writes magic bytes at the beginning of the file.
fn start_file(buf: &mut TrackedWrite<W>) -> Result<()> {
buf.write_all(&PARQUET_MAGIC)?;
Expand Down Expand Up @@ -303,6 +322,7 @@ impl<W: Write + Send> SerializedFileWriter<W> {

/// Assembles and writes metadata at the end of the file.
fn write_metadata(&mut self) -> Result<parquet::FileMetaData> {
self.finished = true;
let num_rows = self.row_groups.iter().map(|x| x.num_rows()).sum();

let mut row_groups = self
Expand Down Expand Up @@ -366,6 +386,10 @@ impl<W: Write + Send> SerializedFileWriter<W> {

#[inline]
fn assert_previous_writer_closed(&self) -> Result<()> {
if self.finished {
return Err(general_err!("SerializedFileWriter already finished"));
}

if self.row_group_index != self.row_groups.len() {
Err(general_err!("Previous row group writer was not closed"))
} else {
Expand All @@ -387,6 +411,18 @@ impl<W: Write + Send> SerializedFileWriter<W> {
&self.props
}

/// Returns a reference to the underlying writer.
pub fn inner(&self) -> &W {
self.buf.inner()
}

/// Returns a mutable reference to the underlying writer.
///
/// It is inadvisable to directly write to the underlying writer.
///
/// NOTE(review): direct writes are presumably untracked by this file
/// writer and would desynchronize the parquet layout — confirm intended
/// use before writing through this.
pub fn inner_mut(&mut self) -> &mut W {
self.buf.inner_mut()
}

/// Writes the file footer and returns the underlying writer.
pub fn into_inner(mut self) -> Result<W> {
self.assert_previous_writer_closed()?;
Expand Down Expand Up @@ -1755,7 +1791,7 @@ mod tests {
b_writer.close().unwrap();
row_group_writer.close().unwrap();

let metadata = file_writer.close().unwrap();
let metadata = file_writer.finish().unwrap();
assert_eq!(metadata.row_groups.len(), 1);
let row_group = &metadata.row_groups[0];
assert_eq!(row_group.columns.len(), 2);
Expand All @@ -1766,6 +1802,11 @@ mod tests {
assert!(row_group.columns[1].offset_index_offset.is_some());
assert!(row_group.columns[1].column_index_offset.is_none());

let err = file_writer.next_row_group().err().unwrap().to_string();
assert_eq!(err, "Parquet error: SerializedFileWriter already finished");

drop(file_writer);

let options = ReadOptionsBuilder::new().with_page_index().build();
let reader = SerializedFileReader::new_with_options(Bytes::from(file), options).unwrap();

Expand Down

0 comments on commit 96de92b

Please sign in to comment.