WIP: async stream of Arrow record batches from Parquet file #258

Draft · wants to merge 2 commits into `main`
27 changes: 21 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion arro3-io/Cargo.toml
@@ -45,5 +45,6 @@ pyo3-async-runtimes = { workspace = true, features = [
"tokio-runtime",
], optional = true }
pyo3-file = { workspace = true }
pyo3-object_store = { git = "https://github.com/developmentseed/object-store-rs", rev = "922b58ff784271345ce80342cf4cd6cddce61adf", optional = true }
pyo3-object_store = { git = "https://github.com/developmentseed/obstore", rev = "f0dad90f1e5e157760335d1ccb4045e1f3b4f194", optional = true }
thiserror = { workspace = true }
tokio = "1.41.1"
26 changes: 23 additions & 3 deletions arro3-io/python/arro3/io/_io.pyi
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import IO, Literal, Sequence
from typing import IO, Literal, Self, Sequence

# Note: importing with
# `from arro3.core import Array`
@@ -267,6 +267,24 @@ ParquetEncoding = Literal[
]
"""Allowed Parquet encodings."""

class ParquetRecordBatchStream:
    """
    A stream of [RecordBatch][core.RecordBatch] that can be polled in a sync or
    async fashion.
    """

    def __aiter__(self) -> Self:
        """Return `Self` as an async iterator."""

    def __iter__(self) -> Self:
        """Return `Self` as a sync iterator."""

    async def collect_async(self) -> core.Table:
        """Collect all remaining batches in the stream into a table."""

    async def __anext__(self) -> core.RecordBatch:
        """Return the next record batch in the stream."""

def read_parquet(file: IO[bytes] | Path | str) -> core.RecordBatchReader:
"""Read a Parquet file to an Arrow RecordBatchReader

@@ -277,8 +295,10 @@ def read_parquet(file: IO[bytes] | Path | str) -> core.RecordBatchReader:
The loaded Arrow data.
"""

async def read_parquet_async(path: str, *, store: ObjectStore) -> core.Table:
"""Read a Parquet file to an Arrow Table in an async fashion
async def read_parquet_async(
path: str, *, store: ObjectStore
) -> ParquetRecordBatchStream:
"""Create an async stream of Arrow record batches from a Parquet file.

Args:
path: The path to the Parquet file in the given store
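For context, a minimal usage sketch of the stubbed API above. The `HTTPStore` setup, URL, and path are illustrative only; any store exposed under `arro3.io.store` should work the same way.

```python
from arro3.io import read_parquet_async
from arro3.io.store import HTTPStore


async def sum_rows(url: str, path: str) -> int:
    """Stream batches one at a time and count rows without building a full table."""
    store = HTTPStore.from_url(url)
    stream = await read_parquet_async(path, store=store)
    total = 0
    async for batch in stream:  # each item is a core.RecordBatch
        total += batch.num_rows
    return total


async def load_table(url: str, path: str):
    """Materialize everything remaining in the stream into a single core.Table."""
    store = HTTPStore.from_url(url)
    stream = await read_parquet_async(path, store=store)
    return await stream.collect_async()
```

Either helper can be driven from synchronous code with `asyncio.run(...)`.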
1 change: 1 addition & 0 deletions arro3-io/src/lib.rs
@@ -40,6 +40,7 @@ fn _io(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(___version))?;

pyo3_object_store::register_store_module(py, m, "arro3.io")?;
pyo3_object_store::register_exceptions_module(py, m, "arro3.io")?;

m.add_wrapped(wrap_pyfunction!(csv::infer_csv_schema))?;
m.add_wrapped(wrap_pyfunction!(csv::read_csv))?;
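Registering the exceptions module means store errors surface as typed Python exceptions. A hedged sketch, assuming the classes mirror obstore's and land under `arro3.io.exceptions` (the `NotFoundError` name is an assumption, not confirmed by this diff):

```python
# Sketch only: assumes register_exceptions_module exposes obstore-style
# exception classes under arro3.io.exceptions; NotFoundError is an assumed name.
from arro3.io import read_parquet_async
from arro3.io.store import HTTPStore
from arro3.io.exceptions import NotFoundError  # assumed export


async def open_stream_or_none(url: str, path: str):
    store = HTTPStore.from_url(url)
    try:
        return await read_parquet_async(path, store=store)
    except NotFoundError:
        # Raised by the store layer when the object does not exist.
        return None
```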
105 changes: 91 additions & 14 deletions arro3-io/src/parquet.rs
@@ -3,22 +3,26 @@ use std::str::FromStr;
use std::sync::Arc;

use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::SchemaRef;
use futures::StreamExt;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use parquet::arrow::arrow_writer::ArrowWriterOptions;
use parquet::arrow::async_reader::ParquetObjectReader;
use parquet::arrow::async_reader::{ParquetObjectReader, ParquetRecordBatchStream};
use parquet::arrow::ArrowWriter;
use parquet::arrow::ParquetRecordBatchStreamBuilder;
use parquet::basic::{Compression, Encoding};
use parquet::file::properties::{WriterProperties, WriterVersion};
use parquet::format::KeyValue;
use parquet::schema::types::ColumnPath;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::exceptions::{PyStopAsyncIteration, PyStopIteration, PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3_arrow::error::PyArrowResult;
use pyo3_arrow::input::AnyRecordBatch;
use pyo3_arrow::{PyRecordBatchReader, PyTable};
use pyo3_arrow::{PyRecordBatch, PyRecordBatchReader, PyTable};
use pyo3_object_store::PyObjectStore;
use tokio::sync::Mutex;

use crate::error::Arro3IoResult;
use crate::error::{Arro3IoError, Arro3IoResult};
use crate::utils::{FileReader, FileWriter};

#[pyfunction]
Expand Down Expand Up @@ -49,21 +53,92 @@ pub fn read_parquet_async(
py: Python,
path: String,
store: PyObjectStore,
) -> PyArrowResult<PyObject> {
let fut = pyo3_async_runtimes::tokio::future_into_py(py, async move {
) -> PyResult<Bound<PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
Ok(read_parquet_async_inner(store.into_inner(), path).await?)
})?;
})
}

struct PyRecordBatchWrapper(PyRecordBatch);

impl IntoPy<PyObject> for PyRecordBatchWrapper {
fn into_py(self, py: Python<'_>) -> PyObject {
self.0.to_arro3(py).unwrap()
}
}

struct PyTableWrapper(PyTable);

impl IntoPy<PyObject> for PyTableWrapper {
fn into_py(self, py: Python<'_>) -> PyObject {
self.0.to_arro3(py).unwrap()
}
}

#[pyclass(name = "ParquetRecordBatchStream")]
struct PyParquetRecordBatchStream {
stream: Arc<Mutex<ParquetRecordBatchStream<ParquetObjectReader>>>,
schema: SchemaRef,
}

#[pymethods]
impl PyParquetRecordBatchStream {
fn __aiter__(slf: Py<Self>) -> Py<Self> {
slf
}

fn __anext__<'py>(&'py mut self, py: Python<'py>) -> PyResult<Bound<PyAny>> {
let stream = self.stream.clone();
pyo3_async_runtimes::tokio::future_into_py(py, next_stream(stream, false))
}

fn collect_async<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<PyAny>> {
let stream = self.stream.clone();
pyo3_async_runtimes::tokio::future_into_py(py, collect_stream(stream, self.schema.clone()))
}
}

Ok(fut.into())
async fn next_stream(
stream: Arc<Mutex<ParquetRecordBatchStream<ParquetObjectReader>>>,
sync: bool,
) -> PyResult<PyRecordBatchWrapper> {
let mut stream = stream.lock().await;
match stream.next().await {
Some(Ok(batch)) => Ok(PyRecordBatchWrapper(PyRecordBatch::new(batch))),
Some(Err(err)) => Err(Arro3IoError::ParquetError(err).into()),
None => {
// Depending on whether the iteration is sync or not, we raise either a
// StopIteration or a StopAsyncIteration
if sync {
Err(PyStopIteration::new_err("stream exhausted"))
} else {
Err(PyStopAsyncIteration::new_err("stream exhausted"))
}
}
}
}

async fn collect_stream(
stream: Arc<Mutex<ParquetRecordBatchStream<ParquetObjectReader>>>,
schema: SchemaRef,
) -> PyResult<PyTableWrapper> {
let mut stream = stream.lock().await;
let mut batches: Vec<_> = vec![];
loop {
match stream.next().await {
Some(Ok(batch)) => {
batches.push(batch);
}
Some(Err(err)) => return Err(Arro3IoError::ParquetError(err).into()),
None => return Ok(PyTableWrapper(PyTable::try_new(batches, schema)?)),
};
}
}

async fn read_parquet_async_inner(
store: Arc<dyn object_store::ObjectStore>,
path: String,
) -> Arro3IoResult<PyTable> {
use futures::TryStreamExt;
use parquet::arrow::ParquetRecordBatchStreamBuilder;

) -> Arro3IoResult<PyParquetRecordBatchStream> {
let meta = store.head(&path.into()).await?;

let object_reader = ParquetObjectReader::new(store, meta);
@@ -74,8 +149,10 @@ async fn read_parquet_async_inner(

let arrow_schema = Arc::new(reader.schema().as_ref().clone().with_metadata(metadata));

let batches = reader.try_collect::<Vec<_>>().await?;
Ok(PyTable::try_new(batches, arrow_schema)?)
Ok(PyParquetRecordBatchStream {
stream: Arc::new(Mutex::new(reader)),
schema: arrow_schema,
})
}

pub(crate) struct PyWriterVersion(WriterVersion);
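A short sketch of the exhaustion behaviour that `next_stream` implements above: the stream is shared behind a `Mutex`, and once it is drained, awaiting `__anext__` raises `StopAsyncIteration`, so `async for` simply stops. Store and path are illustrative.

```python
from arro3.io import read_parquet_async
from arro3.io.store import HTTPStore


async def drain(url: str, path: str):
    store = HTTPStore.from_url(url)
    stream = await read_parquet_async(path, store=store)

    # collect_async consumes whatever is left in the shared stream.
    table = await stream.collect_async()

    # A further poll on the exhausted stream raises StopAsyncIteration.
    try:
        await stream.__anext__()
    except StopAsyncIteration:
        pass
    return table
```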
28 changes: 27 additions & 1 deletion tests/io/test_parquet.py
@@ -2,7 +2,8 @@

import pyarrow as pa
import pyarrow.parquet as pq
from arro3.io import read_parquet, write_parquet
from arro3.io import read_parquet, read_parquet_async, write_parquet
from arro3.io.store import HTTPStore


def test_parquet_round_trip():
@@ -42,3 +43,28 @@ def test_copy_parquet_kv_metadata():

reader = read_parquet("test.parquet")
assert reader.schema.metadata[b"hello"] == b"world"


async def test_stream_parquet():
    url = "https://overturemaps-us-west-2.s3.amazonaws.com/release/2024-03-12-alpha.0/theme=buildings/type=building/part-00217-4dfc75cd-2680-4d52-b5e0-f4cc9f36b267-c000.zstd.parquet"
    store = HTTPStore.from_url(url)
    stream = await read_parquet_async("", store=store)

    # Creating the stream and pulling the first batch should not require
    # downloading the whole file.
    first = await stream.__anext__()
    assert first.num_rows > 0

    # The stream is also consumable with `async for`.
    async for batch in stream:
        assert batch.num_rows > 0
        break

    # Collect everything remaining in the stream into a single table.
    table = await stream.collect_async()
    assert len(table) > 0

    # The stream is now exhausted, so a further poll raises StopAsyncIteration.
    exhausted = False
    try:
        await stream.__anext__()
    except StopAsyncIteration:
        exhausted = True
    assert exhausted