45 changes: 33 additions & 12 deletions arrow-array/src/ffi_stream.rs
@@ -364,7 +364,9 @@ impl Iterator for ArrowArrayStreamReader {
             let result = unsafe {
                 from_ffi_and_data_type(array, DataType::Struct(self.schema().fields().clone()))
             };
-            Some(result.map(|data| RecordBatch::from(StructArray::from(data))))
+            Some(result.and_then(|data| {
+                RecordBatch::try_new(self.schema.clone(), StructArray::from(data).into_parts().1)
+            }))
         } else {
             let last_error = self.get_stream_last_error();
             let err = ArrowError::CDataInterface(last_error.unwrap());
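Note (not part of this diff): the old path `RecordBatch::from(StructArray::from(data))` loses schema-level key/value metadata, because a `StructArray` carries only fields, columns, and a null buffer; rebuilding the batch against the reader's own `self.schema` via `RecordBatch::try_new` keeps that metadata attached. A minimal standalone sketch of the difference, assuming an arrow-rs version that provides `StructArray::into_parts`:

```rust
use std::{collections::HashMap, sync::Arc};

use arrow_array::{ArrayRef, Int32Array, RecordBatch, StructArray};
use arrow_schema::{DataType, Field, Schema};

fn main() {
    let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
    let schema = Arc::new(
        Schema::new(vec![Field::new("a", DataType::Int32, true)]).with_metadata(metadata),
    );
    let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_new(schema.clone(), vec![col]).unwrap();

    // Round-tripping through StructArray keeps the columns and the per-field
    // metadata, but a StructArray has nowhere to store schema-level metadata.
    let lossy = RecordBatch::from(StructArray::from(batch.clone()));
    assert!(lossy.schema().metadata().is_empty());

    // Rebuilding against the original SchemaRef, as the reader now does,
    // preserves it. into_parts() yields (fields, columns, null buffer);
    // index 1 is the Vec<ArrayRef> of columns.
    let columns = StructArray::from(batch).into_parts().1;
    let preserved = RecordBatch::try_new(schema.clone(), columns).unwrap();
    assert_eq!(preserved.schema().metadata(), schema.metadata());
}
```

Using `and_then` rather than `map` also surfaces any `try_new` validation error through the iterator instead of panicking.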
@@ -382,6 +384,7 @@ impl RecordBatchReader for ArrowArrayStreamReader {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use std::collections::HashMap;

     use arrow_schema::Field;

@@ -417,11 +420,18 @@ mod tests {
     }

     fn _test_round_trip_export(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
-        let schema = Arc::new(Schema::new(vec![
-            Field::new("a", arrays[0].data_type().clone(), true),
-            Field::new("b", arrays[1].data_type().clone(), true),
-            Field::new("c", arrays[2].data_type().clone(), true),
-        ]));
+        let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
+        let schema = Arc::new(Schema::new_with_metadata(
+            vec![
+                Field::new("a", arrays[0].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+                Field::new("b", arrays[1].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+                Field::new("c", arrays[2].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+            ],
+            metadata,
+        ));
         let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
         let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _;

@@ -452,7 +462,11 @@

         let array = unsafe { from_ffi(ffi_array, &ffi_schema) }.unwrap();

-        let record_batch = RecordBatch::from(StructArray::from(array));
+        let record_batch = RecordBatch::try_new(
+            SchemaRef::from(exported_schema.clone()),
+            StructArray::from(array).into_parts().1,
+        )
+        .unwrap();
         produced_batches.push(record_batch);
     }

@@ -462,11 +476,18 @@
     }

     fn _test_round_trip_import(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
-        let schema = Arc::new(Schema::new(vec![
-            Field::new("a", arrays[0].data_type().clone(), true),
-            Field::new("b", arrays[1].data_type().clone(), true),
-            Field::new("c", arrays[2].data_type().clone(), true),
-        ]));
+        let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
+        let schema = Arc::new(Schema::new_with_metadata(
+            vec![
+                Field::new("a", arrays[0].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+                Field::new("b", arrays[1].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+                Field::new("c", arrays[2].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+            ],
+            metadata,
+        ));
         let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
         let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _;
28 changes: 16 additions & 12 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -527,7 +527,7 @@ def test_empty_recordbatch_with_row_count():
     """

     # Create an empty schema with no fields
-    batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3, 4]}).select([])
+    batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3, 4]}, metadata={b'key1': b'value1'}).select([])
     num_rows = 4
     assert batch.num_rows == num_rows
     assert batch.num_columns == 0
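Note (not part of this diff): the tests below also switch from `('ints', pa.list_(pa.int32()))` tuples to explicit `pa.field(...)` objects, so the round trip now exercises field-level metadata in addition to schema-level metadata. A standalone sketch of where each kind lives in PyArrow:

```python
import pyarrow as pa

# Field-level metadata hangs off each pa.field; schema-level metadata
# hangs off the pa.schema itself. Both are byte-string key/value maps.
field = pa.field('ints', pa.list_(pa.int32()), metadata={b'key1': b'value1'})
schema = pa.schema([field], metadata={b'key1': b'value1'})

assert schema.metadata == {b'key1': b'value1'}
assert schema.field('ints').metadata == {b'key1': b'value1'}
```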
@@ -545,7 +545,7 @@ def test_record_batch_reader():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batches = [
         pa.record_batch([[[1], [2, 42]]], schema),
         pa.record_batch([[None, [], [5, 6]]], schema),
@@ -571,7 +571,7 @@ def test_record_batch_reader_pycapsule():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batches = [
         pa.record_batch([[[1], [2, 42]]], schema),
         pa.record_batch([[None, [], [5, 6]]], schema),
@@ -621,7 +621,7 @@ def test_record_batch_pycapsule():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batch = pa.record_batch([[[1], [2, 42]]], schema)
     wrapped = StreamWrapper(batch)
     b = rust.round_trip_record_batch_reader(wrapped)
@@ -640,7 +640,7 @@ def test_table_pycapsule():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batches = [
         pa.record_batch([[[1], [2, 42]]], schema),
         pa.record_batch([[None, [], [5, 6]]], schema),
@@ -650,55 +650,59 @@
     b = rust.round_trip_record_batch_reader(wrapped)
     new_table = b.read_all()

-    assert table.schema == new_table.schema
     assert table == new_table
+    assert table.schema == new_table.schema
+    assert table.schema.metadata == new_table.schema.metadata
     assert len(table.to_batches()) == len(new_table.to_batches())


 def test_table_empty():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     table = pa.Table.from_batches([], schema=schema)
     new_table = rust.build_table([], schema=schema)

-    assert table.schema == new_table.schema
     assert table == new_table
+    assert table.schema == new_table.schema
+    assert table.schema.metadata == new_table.schema.metadata
     assert len(table.to_batches()) == len(new_table.to_batches())


 def test_table_roundtrip():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))])
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batches = [
         pa.record_batch([[[1], [2, 42]]], schema),
         pa.record_batch([[None, [], [5, 6]]], schema),
     ]
     table = pa.Table.from_batches(batches, schema=schema)
     new_table = rust.round_trip_table(table)

-    assert table.schema == new_table.schema
     assert table == new_table
+    assert table.schema == new_table.schema
+    assert table.schema.metadata == new_table.schema.metadata
     assert len(table.to_batches()) == len(new_table.to_batches())


 def test_table_from_batches():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batches = [
         pa.record_batch([[[1], [2, 42]]], schema),
         pa.record_batch([[None, [], [5, 6]]], schema),
     ]
     table = pa.Table.from_batches(batches)
     new_table = rust.build_table(batches, schema)

-    assert table.schema == new_table.schema
     assert table == new_table
+    assert table.schema == new_table.schema
+    assert table.schema.metadata == new_table.schema.metadata
     assert len(table.to_batches()) == len(new_table.to_batches())

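Note (not part of this diff): the explicit `.metadata` assertions above matter because PyArrow's `==` on schemas delegates to `Schema.equals`, which ignores metadata unless `check_metadata=True` is passed; `table.schema == new_table.schema` alone would pass even if the round trip dropped the metadata. A standalone sketch:

```python
import pyarrow as pa

a = pa.schema([('ints', pa.int32())], metadata={b'key1': b'value1'})
b = pa.schema([('ints', pa.int32())])  # same fields, no metadata

assert a == b                                # default comparison ignores metadata
assert not a.equals(b, check_metadata=True)  # strict comparison sees the difference
assert a.metadata != b.metadata              # a.metadata is a dict, b.metadata is None
```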
