49 changes: 37 additions & 12 deletions arrow-array/src/ffi_stream.rs
@@ -364,7 +364,13 @@ impl Iterator for ArrowArrayStreamReader {
let result = unsafe {
from_ffi_and_data_type(array, DataType::Struct(self.schema().fields().clone()))
};
Some(result.map(|data| RecordBatch::from(StructArray::from(data))))
Some(result.map(|data| {
let struct_array = StructArray::from(data);
Contributor:
Can you please leave a comment here explaining:

  1. The rationale for this form rather than just converting StructArray to RecordBatch
  2. A SAFETY comment explaining why the unsafe is ok (aka how the invariants required for RecordBatch::new_unchecked are satisfied)

Contributor Author:
> The rationale for this form rather than just converting StructArray to RecordBatch

Basically what I explained in #8790 (comment): a StructArray is by definition metadata-less, so the resulting RecordBatch won't have any metadata attached if you just return it as-is.

I'm not sure whether there is a more elegant way to construct a RecordBatch with the corresponding metadata from ArrayData. Right now I'm going through StructArray because the previous interface did that too. If there is a more elegant way, please let me know.

Other ways to attach metadata to an existing RecordBatch would be, as far as I can see, to call with_schema() (which will incur some "is subschema" check costs) or somehow through schema_metadata_mut(), but that interface feels a bit clunky for this specific task IMHO.
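
For reference, a minimal sketch of the with_schema() route mentioned above (assuming the current RecordBatch::with_schema API; the compatibility check it runs is exactly the cost being discussed, and the helper name is illustrative only):

```rust
use arrow_array::{RecordBatch, StructArray};
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, SchemaRef};

/// Convert through StructArray as the previous code did, then re-attach the
/// reader's schema (and therefore its metadata). `with_schema` validates that
/// the batch's columns are compatible with the supplied schema.
fn batch_with_reader_schema(data: ArrayData, schema: SchemaRef) -> Result<RecordBatch, ArrowError> {
    let batch = RecordBatch::from(StructArray::from(data));
    batch.with_schema(schema)
}
```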

> A SAFETY comment explaining why the unsafe is ok (aka how the invariants required for RecordBatch::new_unchecked are satisfied)

One reason for the unsafe here is that I did not want to introduce a performance penalty compared to what the interface did before (it just returned a RecordBatch without checking that it actually corresponds to the schema of ArrowArrayStreamReader; and the schemas actually mismatched before my change, at least metadata-wise).

In principle the Iterator impl of ArrowArrayStreamReader returns a Result, so we could make this fallible through RecordBatch::try_new(...). This would incur some costs though, such as checking each column for correct nullability, matching row counts, matching types, etc.

I would have guessed that at least data-wise the interface can be trusted and the checks can therefore be omitted? 😅 I'm really not the expert here; I assumed someone from the arrow-rs team would have an opinion on this 😬
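
For comparison, a minimal sketch of the fully checked construction via RecordBatch::try_new discussed above (helper name is illustrative only):

```rust
use arrow_array::{RecordBatch, StructArray};
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, SchemaRef};

/// Fallible alternative to `new_unchecked`: `try_new` re-validates column
/// count, row counts, data types and nullability against the schema instead
/// of trusting the producer, at the cost of those checks per batch.
fn checked_batch(data: ArrayData, schema: SchemaRef) -> Result<RecordBatch, ArrowError> {
    let struct_array = StructArray::from(data);
    let (_, columns, _) = struct_array.into_parts();
    RecordBatch::try_new(schema, columns)
}
```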

let row_count = struct_array.len();
let (_, arrays, _) = struct_array.into_parts();

unsafe { RecordBatch::new_unchecked(self.schema.clone(), arrays, row_count) }
}))
} else {
let last_error = self.get_stream_last_error();
let err = ArrowError::CDataInterface(last_error.unwrap());
@@ -382,6 +388,7 @@ impl RecordBatchReader for ArrowArrayStreamReader {
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;

use arrow_schema::Field;

@@ -417,11 +424,14 @@ mod tests {
}

fn _test_round_trip_export(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", arrays[0].data_type().clone(), true),
Field::new("b", arrays[1].data_type().clone(), true),
Field::new("c", arrays[2].data_type().clone(), true),
]));
let schema = Arc::new(Schema::new_with_metadata(
vec![
Field::new("a", arrays[0].data_type().clone(), true),
Field::new("b", arrays[1].data_type().clone(), true),
Field::new("c", arrays[2].data_type().clone(), true),
],
HashMap::from([("foo".to_owned(), "bar".to_owned())]),
));
let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _;

@@ -452,7 +462,19 @@

let array = unsafe { from_ffi(ffi_array, &ffi_schema) }.unwrap();

let record_batch = RecordBatch::from(StructArray::from(array));
let record_batch = {
let struct_array = StructArray::from(array);
let row_count = struct_array.len();
let (_, arrays, _) = struct_array.into_parts();

unsafe {
RecordBatch::new_unchecked(
SchemaRef::from(exported_schema.clone()),
arrays,
row_count,
)
}
};
produced_batches.push(record_batch);
}

@@ -462,11 +484,14 @@
}

fn _test_round_trip_import(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", arrays[0].data_type().clone(), true),
Field::new("b", arrays[1].data_type().clone(), true),
Field::new("c", arrays[2].data_type().clone(), true),
]));
let schema = Arc::new(Schema::new_with_metadata(
vec![
Field::new("a", arrays[0].data_type().clone(), true),
Field::new("b", arrays[1].data_type().clone(), true),
Field::new("c", arrays[2].data_type().clone(), true),
],
HashMap::from([("foo".to_owned(), "bar".to_owned())]),
));
let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _;

24 changes: 23 additions & 1 deletion arrow-pyarrow-integration-testing/src/lib.rs
@@ -32,7 +32,7 @@ use arrow::compute::kernels;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, ToPyArrow};
use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, Table, ToPyArrow};
use arrow::record_batch::RecordBatch;

fn to_py_err(err: ArrowError) -> PyErr {
@@ -140,6 +140,26 @@ fn round_trip_record_batch_reader(
Ok(obj)
}

#[pyfunction]
fn round_trip_table(obj: PyArrowType<Table>) -> PyResult<PyArrowType<Table>> {
Ok(obj)
}

/// Builds a Table from a list of RecordBatches and a Schema.
#[pyfunction]
pub fn build_table(
record_batches: Vec<PyArrowType<RecordBatch>>,
schema: PyArrowType<Schema>,
) -> PyResult<PyArrowType<Table>> {
Ok(PyArrowType(
Table::try_new(
record_batches.into_iter().map(|rb| rb.0).collect(),
Arc::new(schema.0),
)
.map_err(to_py_err)?,
))
}

#[pyfunction]
fn reader_return_errors(obj: PyArrowType<ArrowArrayStreamReader>) -> PyResult<()> {
// This makes sure we can correctly consume a RBR and return the error,
@@ -178,6 +198,8 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &Bound<PyModule>) -> PyResu
m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
m.add_wrapped(wrap_pyfunction!(round_trip_table))?;
m.add_wrapped(wrap_pyfunction!(build_table))?;
m.add_wrapped(wrap_pyfunction!(reader_return_errors))?;
m.add_wrapped(wrap_pyfunction!(boxed_reader_roundtrip))?;
Ok(())
110 changes: 99 additions & 11 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -20,6 +20,7 @@
import datetime
import decimal
import string
from typing import Union, Tuple, Protocol

import pytest
import pyarrow as pa
@@ -120,28 +121,50 @@ def assert_pyarrow_leak():
# This defines that Arrow consumers should allow any object that has specific "dunder"
# methods, `__arrow_c_*_`. These wrapper classes ensure that arrow-rs is able to handle
# _any_ class, without pyarrow-specific handling.
class SchemaWrapper:
def __init__(self, schema):


class ArrowSchemaExportable(Protocol):
def __arrow_c_schema__(self) -> object: ...


class ArrowArrayExportable(Protocol):
def __arrow_c_array__(
self,
requested_schema: Union[object, None] = None
) -> Tuple[object, object]:
...


class ArrowStreamExportable(Protocol):
def __arrow_c_stream__(
self,
requested_schema: Union[object, None] = None
) -> object:
...


class SchemaWrapper(ArrowSchemaExportable):
def __init__(self, schema: ArrowSchemaExportable) -> None:
self.schema = schema

def __arrow_c_schema__(self):
def __arrow_c_schema__(self) -> object:
return self.schema.__arrow_c_schema__()


class ArrayWrapper:
def __init__(self, array):
class ArrayWrapper(ArrowArrayExportable):
def __init__(self, array: ArrowArrayExportable) -> None:
self.array = array

def __arrow_c_array__(self):
return self.array.__arrow_c_array__()
def __arrow_c_array__(self, requested_schema: Union[object, None] = None) -> Tuple[object, object]:
return self.array.__arrow_c_array__(requested_schema=requested_schema)


class StreamWrapper:
def __init__(self, stream):
class StreamWrapper(ArrowStreamExportable):
def __init__(self, stream: ArrowStreamExportable) -> None:
self.stream = stream

def __arrow_c_stream__(self):
return self.stream.__arrow_c_stream__()
def __arrow_c_stream__(self, requested_schema: Union[object, None] = None) -> object:
return self.stream.__arrow_c_stream__(requested_schema=requested_schema)


@pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
@@ -613,6 +636,71 @@ def test_table_pycapsule():
assert len(table.to_batches()) == len(new_table.to_batches())


def test_table_empty():
"""
Python -> Rust -> Python
"""
schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
table = pa.Table.from_batches([], schema=schema)
new_table = rust.build_table([], schema=schema)

assert table.schema == new_table.schema
assert table == new_table
assert len(table.to_batches()) == len(new_table.to_batches())


def test_table_roundtrip():
"""
Python -> Rust -> Python
"""
metadata = {b'key1': b'value1'}
schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata=metadata)
batches = [
pa.record_batch([[[1], [2, 42]]], schema),
pa.record_batch([[None, [], [5, 6]]], schema),
]
table = pa.Table.from_batches(batches, schema=schema)
# TODO: Remove these `assert`s as soon as the metadata issue is solved in Rust
assert table.schema.metadata == metadata
assert all(batch.schema.metadata == metadata for batch in table.to_batches())
new_table = rust.round_trip_table(table)

assert table.schema == new_table.schema
assert table == new_table
assert len(table.to_batches()) == len(new_table.to_batches())


def test_table_from_batches():
"""
Python -> Rust -> Python
"""
schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
batches = [
pa.record_batch([[[1], [2, 42]]], schema),
pa.record_batch([[None, [], [5, 6]]], schema),
]
table = pa.Table.from_batches(batches)
new_table = rust.build_table(batches, schema)

assert table.schema == new_table.schema
assert table == new_table
assert len(table.to_batches()) == len(new_table.to_batches())


def test_table_error_inconsistent_schema():
"""
Python -> Rust -> Python
"""
schema_1 = pa.schema([('ints', pa.list_(pa.int32()))])
schema_2 = pa.schema([('floats', pa.list_(pa.float32()))])
batches = [
pa.record_batch([[[1], [2, 42]]], schema_1),
pa.record_batch([[None, [], [5.6, 6.4]]], schema_2),
]
with pytest.raises(pa.ArrowException, match="Schema error: All record batches must have the same schema."):
rust.build_table(batches, schema_1)


def test_reject_other_classes():
# Arbitrary type that is not a PyArrow type
not_pyarrow = ["hello"]