Skip to content

Commit 37b46be

Browse files
committed
Implement a Vec<RecordBatch> wrapper for pyarrow.Table convenience
CQ fixes (×3); let `Table` be a combination of `Vec<RecordBatch>` and `SchemaRef` instead; `cargo fmt`; overhauled the `Table` definition and added tests; added an empty-`Table` integration test; updated `arrow-pyarrow`'s crate documentation; further documentation overhaul; typo fix.
1 parent 40300ca commit 37b46be

File tree

3 files changed

+241
-9
lines changed

3 files changed

+241
-9
lines changed

arrow-pyarrow-integration-testing/src/lib.rs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use arrow::compute::kernels;
3232
use arrow::datatypes::{DataType, Field, Schema};
3333
use arrow::error::ArrowError;
3434
use arrow::ffi_stream::ArrowArrayStreamReader;
35-
use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, ToPyArrow};
35+
use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, Table, ToPyArrow};
3636
use arrow::record_batch::RecordBatch;
3737

3838
fn to_py_err(err: ArrowError) -> PyErr {
@@ -140,6 +140,28 @@ fn round_trip_record_batch_reader(
140140
Ok(obj)
141141
}
142142

143+
/// Round-trips a `pyarrow.Table` through Rust's `Table` wrapper: pyo3 imports the Python
/// object via `FromPyArrow`, and returning it re-exports it via `IntoPyArrow`.
#[pyfunction]
fn round_trip_table(obj: PyArrowType<Table>) -> PyResult<PyArrowType<Table>> {
    Ok(obj)
}
147+
148+
/// Function for testing whether a `Vec<RecordBatch>` is exportable as `pyarrow.Table`, with or
149+
/// without explicitly providing a schema
150+
#[pyfunction]
151+
#[pyo3(signature = (record_batches, *, schema=None))]
152+
pub fn build_table(
153+
record_batches: Vec<PyArrowType<RecordBatch>>,
154+
schema: Option<PyArrowType<Schema>>,
155+
) -> PyResult<PyArrowType<Table>> {
156+
Ok(PyArrowType(
157+
Table::try_new(
158+
record_batches.into_iter().map(|rb| rb.0).collect(),
159+
schema.map(|s| Arc::new(s.0)),
160+
)
161+
.map_err(to_py_err)?,
162+
))
163+
}
164+
143165
#[pyfunction]
144166
fn reader_return_errors(obj: PyArrowType<ArrowArrayStreamReader>) -> PyResult<()> {
145167
// This makes sure we can correctly consume a RBR and return the error,
@@ -178,6 +200,8 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &Bound<PyModule>) -> PyResu
178200
m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
179201
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
180202
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
203+
m.add_wrapped(wrap_pyfunction!(round_trip_table))?;
204+
m.add_wrapped(wrap_pyfunction!(build_table))?;
181205
m.add_wrapped(wrap_pyfunction!(reader_return_errors))?;
182206
m.add_wrapped(wrap_pyfunction!(boxed_reader_roundtrip))?;
183207
Ok(())

arrow-pyarrow-integration-testing/tests/test_sql.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,80 @@ def test_table_pycapsule():
613613
assert len(table.to_batches()) == len(new_table.to_batches())
614614

615615

616+
def test_table_empty():
    """
    Python -> Rust -> Python

    Builds a Rust `Table` from an empty batch list (explicit schema required) and
    checks it equals an empty `pyarrow.Table` with the same schema and metadata.
    """
    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
    table = pa.Table.from_batches([], schema=schema)
    new_table = rust.build_table([], schema=schema)

    assert table.schema == new_table.schema
    assert table == new_table
    assert len(table.to_batches()) == len(new_table.to_batches())
627+
628+
629+
def test_table_roundtrip():
    """
    Python -> Rust -> Python

    Sends a pyarrow.Table through the Rust `Table` wrapper and verifies the
    round-tripped copy matches in schema, contents, and batch count.
    """
    list_schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
    chunk_values = ([[1], [2, 42]], [None, [], [5, 6]])
    record_batches = [pa.record_batch([column], list_schema) for column in chunk_values]
    original = pa.Table.from_batches(record_batches)

    round_tripped = rust.round_trip_table(original)

    assert original.schema == round_tripped.schema
    assert original == round_tripped
    assert len(original.to_batches()) == len(round_tripped.to_batches())
644+
645+
646+
@pytest.mark.parametrize("set_schema", (True, False))
def test_table_from_batches(set_schema: bool):
    """
    Python -> Rust -> Python

    Builds a Rust `Table` from non-empty record batches via `build_table`, both
    with an explicit schema and with the schema inferred from the first batch,
    and compares the result against the equivalent pyarrow.Table.
    """
    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
    batches = [
        pa.record_batch([[[1], [2, 42]]], schema),
        pa.record_batch([[None, [], [5, 6]]], schema),
    ]
    table = pa.Table.from_batches(batches)
    new_table = rust.build_table(batches, schema=schema if set_schema else None)

    assert table.schema == new_table.schema
    assert table == new_table
    assert len(table.to_batches()) == len(new_table.to_batches())
662+
663+
664+
def test_table_error_inconsistent_schema():
    """
    Python -> Rust -> Python

    `build_table` must reject record batches whose schemas disagree.
    """
    schema_1 = pa.schema([('ints', pa.list_(pa.int32()))])
    schema_2 = pa.schema([('floats', pa.list_(pa.float32()))])
    batches = [
        pa.record_batch([[[1], [2, 42]]], schema_1),
        pa.record_batch([[None, [], [5.6, 6.4]]], schema_2),
    ]
    # `match` is interpreted as a regular expression, so escape the literal dot.
    with pytest.raises(pa.ArrowException, match=r"Schema error: All record batches must have the same schema\."):
        rust.build_table(batches)
676+
677+
678+
def test_table_error_no_schema():
    """
    Python -> Rust -> Python

    `build_table` must fail when given no record batches and no explicit schema,
    since there is nothing to infer a schema from.
    """
    batches = []
    with pytest.raises(
        pa.ArrowException,
        match="Schema error: If no schema is supplied explicitly, there must be at least one RecordBatch!"
    ):
        rust.build_table(batches)
688+
689+
616690
def test_reject_other_classes():
617691
# Arbitrary type that is not a PyArrow type
618692
not_pyarrow = ["hello"]

arrow-pyarrow/src/lib.rs

Lines changed: 142 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,20 @@
4444
//! | `pyarrow.Array` | [ArrayData] |
4545
//! | `pyarrow.RecordBatch` | [RecordBatch] |
4646
//! | `pyarrow.RecordBatchReader` | [ArrowArrayStreamReader] / `Box<dyn RecordBatchReader + Send>` (1) |
47+
//! | `pyarrow.Table` | [Table] (2) |
4748
//!
4849
//! (1) `pyarrow.RecordBatchReader` can be imported as [ArrowArrayStreamReader]. Either
4950
//! [ArrowArrayStreamReader] or `Box<dyn RecordBatchReader + Send>` can be exported
5051
//! as `pyarrow.RecordBatchReader`. (`Box<dyn RecordBatchReader + Send>` is typically
5152
//! easier to create.)
5253
//!
53-
//! PyArrow has the notion of chunked arrays and tables, but arrow-rs doesn't
54-
//! have these same concepts. A chunked table is instead represented with
55-
//! `Vec<RecordBatch>`. A `pyarrow.Table` can be imported to Rust by calling
56-
//! [pyarrow.Table.to_reader()](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_reader)
57-
//! and then importing the reader as a [ArrowArrayStreamReader].
54+
//! (2) arrow-rs provides [Table], a convenience wrapper around `Vec<RecordBatch>`, for
//! interoperating with [pyarrow.Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table).
//! It is mainly intended for use cases where you already have a `Vec<RecordBatch>` on the Rust
//! side and want to export it in bulk as a `pyarrow.Table`. In general, streaming approaches are
//! recommended over moving bulk data: for example, a `pyarrow.Table` can instead be imported into
//! Rust through `PyArrowType<ArrowArrayStreamReader>` (since `pyarrow.Table` implements the
//! ArrayStream PyCapsule interface).
5861
5962
use std::convert::{From, TryFrom};
6063
use std::ptr::{addr_of, addr_of_mut};
@@ -68,13 +71,13 @@ use arrow_array::{
6871
make_array,
6972
};
7073
use arrow_data::ArrayData;
71-
use arrow_schema::{ArrowError, DataType, Field, Schema};
74+
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
7275
use pyo3::exceptions::{PyTypeError, PyValueError};
7376
use pyo3::ffi::Py_uintptr_t;
74-
use pyo3::import_exception;
7577
use pyo3::prelude::*;
7678
use pyo3::pybacked::PyBackedStr;
77-
use pyo3::types::{PyCapsule, PyList, PyTuple};
79+
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
80+
use pyo3::{import_exception, intern};
7881

7982
import_exception!(pyarrow, ArrowException);
8083
/// Represents an exception raised by PyArrow.
@@ -484,6 +487,137 @@ impl IntoPyArrow for ArrowArrayStreamReader {
484487
}
485488
}
486489

490+
/// This is a convenience wrapper around `Vec<RecordBatch>` that tries to simplify conversion from
491+
/// and to `pyarrow.Table`.
492+
///
493+
/// This could be used in circumstances where you either want to consume a `pyarrow.Table` directly
494+
/// (although technically, since `pyarrow.Table` implements the ArrayStreamReader PyCapsule
495+
/// interface, one could also consume a `PyArrowType<ArrowArrayStreamReader>` instead) or, more
496+
/// importantly, where one wants to export a `pyarrow.Table` from a `Vec<RecordBatch>` from the Rust
497+
/// side.
498+
///
499+
/// ```ignore
500+
/// #[pyfunction]
501+
/// fn return_table(...) -> PyResult<PyArrowType<Table>> {
502+
/// let batches: Vec<RecordBatch>;
503+
/// let schema: SchemaRef;
504+
/// PyArrowType(Table::try_new(batches, schema).map_err(|err| err.into_py_err(py))?)
505+
/// }
506+
/// ```
507+
#[derive(Clone)]
508+
pub struct Table {
509+
record_batches: Vec<RecordBatch>,
510+
schema: SchemaRef,
511+
}
512+
513+
impl Table {
514+
pub unsafe fn new_unchecked(record_batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
515+
Self {
516+
record_batches,
517+
schema,
518+
}
519+
}
520+
521+
pub fn try_new(
522+
record_batches: Vec<RecordBatch>,
523+
schema: Option<SchemaRef>,
524+
) -> Result<Self, ArrowError> {
525+
let schema = match schema {
526+
Some(s) => s,
527+
None => {
528+
record_batches
529+
.get(0)
530+
.ok_or_else(|| ArrowError::SchemaError(
531+
"If no schema is supplied explicitly, there must be at least one RecordBatch!".to_owned()
532+
))?
533+
.schema()
534+
.clone()
535+
}
536+
};
537+
for record_batch in &record_batches {
538+
if schema != record_batch.schema() {
539+
return Err(ArrowError::SchemaError(
540+
"All record batches must have the same schema.".to_owned(),
541+
));
542+
}
543+
}
544+
Ok(Self {
545+
record_batches,
546+
schema,
547+
})
548+
}
549+
550+
pub fn record_batches(&self) -> &[RecordBatch] {
551+
&self.record_batches
552+
}
553+
554+
pub fn schema(&self) -> SchemaRef {
555+
self.schema.clone()
556+
}
557+
558+
pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) {
559+
(self.record_batches, self.schema)
560+
}
561+
}
562+
563+
impl TryFrom<ArrowArrayStreamReader> for Table {
    type Error = ArrowError;

    /// Drains the stream reader, collecting every batch into memory.
    fn try_from(value: ArrowArrayStreamReader) -> Result<Self, ArrowError> {
        let schema = value.schema();
        let batches = value.collect::<Result<Vec<_>, _>>()?;
        // SAFETY: we assume every batch produced by the stream reader conforms to the stream's
        // declared schema, so the unchecked constructor's invariant holds. This is not verified
        // here — a misbehaving stream would yield an inconsistent Table.
        unsafe { Ok(Self::new_unchecked(batches, schema)) }
    }
}
573+
574+
impl FromPyArrow for Table {
    /// Imports a `pyarrow.Table` (or any object exposing an Arrow array stream) by draining it
    /// as a stream of record batches.
    fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> {
        let array_stream_reader: PyResult<ArrowArrayStreamReader> = {
            // First, try whether the object implements the Arrow ArrayStreamReader protocol directly
            // (which `pyarrow.Table` does) or test whether it is a RecordBatchReader.
            let reader_result = if let Ok(reader) = ArrowArrayStreamReader::from_pyarrow_bound(ob) {
                Some(reader)
            }
            // If that is not the case, test whether it has a `to_reader` method (which
            // `pyarrow.Table` does) whose return value implements the Arrow ArrayStreamReader
            // protocol or is a RecordBatchReader.
            else if ob.hasattr(intern!(ob.py(), "to_reader"))? {
                let py_reader = ob.getattr(intern!(ob.py(), "to_reader"))?.call0()?;
                // NOTE(review): a failure here is swallowed by `.ok()` and surfaces as the
                // generic TypeError below, losing the underlying cause.
                ArrowArrayStreamReader::from_pyarrow_bound(&py_reader).ok()
            } else {
                None
            };

            match reader_result {
                Some(reader) => Ok(reader),
                None => Err(PyTypeError::new_err(
                    "Expected Arrow Table, Arrow RecordBatchReader or other object which conforms to the Arrow ArrayStreamReader protocol.",
                )),
            }
        };
        // Materialize the stream into a Table; stream errors are raised as Python ValueError.
        Self::try_from(array_stream_reader?)
            .map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))
    }
}
603+
604+
impl IntoPyArrow for Table {
605+
fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> {
606+
let module = py.import(intern!(py, "pyarrow"))?;
607+
let class = module.getattr(intern!(py, "Table"))?;
608+
609+
let py_batches = PyList::new(py, self.record_batches.into_iter().map(PyArrowType))?;
610+
let py_schema = PyArrowType(Arc::unwrap_or_clone(self.schema));
611+
612+
let kwargs = PyDict::new(py);
613+
kwargs.set_item("schema", py_schema)?;
614+
615+
let reader = class.call_method("from_batches", (py_batches,), Some(&kwargs))?;
616+
617+
Ok(reader)
618+
}
619+
}
620+
487621
/// A newtype wrapper for types implementing [`FromPyArrow`] or [`IntoPyArrow`].
488622
///
489623
/// When wrapped around a type `T: FromPyArrow`, it

0 commit comments

Comments
 (0)