Skip to content

Commit 52047b8

Browse files
committed
Fall back to pyarrow.Table.to_batches() method since it preserves metadata
1 parent bf20c77 commit 52047b8

File tree

1 file changed

+26
-26
lines changed

1 file changed

+26
-26
lines changed

arrow-pyarrow/src/lib.rs

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -515,34 +515,14 @@ impl Table {
515515
record_batches: Vec<RecordBatch>,
516516
schema: SchemaRef,
517517
) -> Result<Self, ArrowError> {
518-
/// This function was copied from `pyo3_arrow/utils.rs` for now. I don't understand yet why
519-
/// this is required instead of a "normal" `schema == record_batch.schema()` check.
520-
///
521-
/// TODO: Either remove this check, replace it with something already existing in `arrow-rs`
522-
/// or move it to a central `utils` location.
523-
fn schema_equals(left: &SchemaRef, right: &SchemaRef) -> bool {
524-
left.fields
525-
.iter()
526-
.zip(right.fields.iter())
527-
.all(|(left_field, right_field)| {
528-
left_field.name() == right_field.name()
529-
&& left_field
530-
.data_type()
531-
.equals_datatype(right_field.data_type())
532-
})
533-
}
534-
535518
for record_batch in &record_batches {
536-
if !schema_equals(&schema, &record_batch.schema()) {
537-
return Err(ArrowError::SchemaError(
538-
//"All record batches must have the same schema.".to_owned(),
539-
format!(
540-
"All record batches must have the same schema. \
519+
if schema != record_batch.schema() {
520+
return Err(ArrowError::SchemaError(format!(
521+
"All record batches must have the same schema. \
541522
Expected schema: {:?}, got schema: {:?}",
542-
schema,
543-
record_batch.schema()
544-
),
545-
));
523+
schema,
524+
record_batch.schema()
525+
)));
546526
}
547527
}
548528
Ok(Self {
@@ -577,6 +557,26 @@ impl TryFrom<Box<dyn RecordBatchReader>> for Table {
577557
/// Convert a `pyarrow.Table` (or any other ArrowArrayStream compliant object) into [`Table`]
578558
impl FromPyArrow for Table {
579559
fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> {
560+
// Try to use to_batches() method if available (e.g., for `pyarrow.Table`)
561+
if ob.hasattr("to_batches")? {
562+
let batches_list = ob.call_method0("to_batches")?;
563+
564+
// Convert Python list of `pyarrow.RecordBatches` to `Vec<RecordBatch>`
565+
if let Ok(list) = batches_list.downcast::<PyList>() {
566+
let batches = list
567+
.iter()
568+
.map(|value| RecordBatch::from_pyarrow_bound(&value))
569+
.collect::<Result<Vec<_>, _>>()?;
570+
571+
// Extract schema from the table
572+
let py_schema = ob.getattr("schema")?;
573+
let schema = Arc::new(Schema::from_pyarrow_bound(&py_schema)?);
574+
575+
return Self::try_new(batches, schema)
576+
.map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()));
577+
}
578+
}
579+
580580
let reader: Box<dyn RecordBatchReader> =
581581
Box::new(ArrowArrayStreamReader::from_pyarrow_bound(ob)?);
582582
Self::try_from(reader).map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))

0 commit comments

Comments
 (0)