4444//! | `pyarrow.Array` | [ArrayData] |
4545//! | `pyarrow.RecordBatch` | [RecordBatch] |
4646//! | `pyarrow.RecordBatchReader` | [ArrowArrayStreamReader] / `Box<dyn RecordBatchReader + Send>` (1) |
47+ //! | `pyarrow.Table` | [Table] (2) |
4748//!
4849//! (1) `pyarrow.RecordBatchReader` can be imported as [ArrowArrayStreamReader]. Either
4950//! [ArrowArrayStreamReader] or `Box<dyn RecordBatchReader + Send>` can be exported
5051//! as `pyarrow.RecordBatchReader`. (`Box<dyn RecordBatchReader + Send>` is typically
5152//! easier to create.)
5253//!
53- //! PyArrow has the notion of chunked arrays and tables, but arrow-rs doesn't
54- //! have these same concepts. A chunked table is instead represented with
55- //! `Vec<RecordBatch>`. A `pyarrow.Table` can be imported to Rust by calling
56- //! [pyarrow.Table.to_reader()](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_reader)
57- //! and then importing the reader as a [ArrowArrayStreamReader].
54+ //! (2) Although arrow-rs offers a [pyarrow.Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html)
55+ //! convenience wrapper [Table] (which internally holds `Vec<RecordBatch>`), this is more meant for
56+ //! use cases where you already have `Vec<RecordBatch>` on the Rust side and want to export that in
57+ //! bulk as a `pyarrow.Table`. In general, it is recommended to use streaming approaches instead of
58+ //! dealing with bulk data.
59+ //! For example, a `pyarrow.Table` can be imported to Rust through `PyArrowType<ArrowArrayStreamReader>`
60+ //! instead (since `pyarrow.Table` implements the ArrayStream PyCapsule interface).
5861
5962use std:: convert:: { From , TryFrom } ;
6063use std:: ptr:: { addr_of, addr_of_mut} ;
@@ -68,13 +71,13 @@ use arrow_array::{
6871 make_array,
6972} ;
7073use arrow_data:: ArrayData ;
71- use arrow_schema:: { ArrowError , DataType , Field , Schema } ;
74+ use arrow_schema:: { ArrowError , DataType , Field , Schema , SchemaRef } ;
7275use pyo3:: exceptions:: { PyTypeError , PyValueError } ;
7376use pyo3:: ffi:: Py_uintptr_t ;
74- use pyo3:: import_exception;
7577use pyo3:: prelude:: * ;
7678use pyo3:: pybacked:: PyBackedStr ;
77- use pyo3:: types:: { PyCapsule , PyList , PyTuple } ;
79+ use pyo3:: types:: { PyCapsule , PyDict , PyList , PyTuple } ;
80+ use pyo3:: { import_exception, intern} ;
7881
7982import_exception ! ( pyarrow, ArrowException ) ;
8083/// Represents an exception raised by PyArrow.
@@ -484,6 +487,137 @@ impl IntoPyArrow for ArrowArrayStreamReader {
484487 }
485488}
486489
490+ /// This is a convenience wrapper around `Vec<RecordBatch>` that tries to simplify conversion from
491+ /// and to `pyarrow.Table`.
492+ ///
493+ /// This could be used in circumstances where you either want to consume a `pyarrow.Table` directly
494+ /// (although technically, since `pyarrow.Table` implements the ArrayStreamReader PyCapsule
495+ /// interface, one could also consume a `PyArrowType<ArrowArrayStreamReader>` instead) or, more
496+ /// importantly, where one wants to export a `pyarrow.Table` from a `Vec<RecordBatch>` from the Rust
497+ /// side.
498+ ///
499+ /// ```ignore
500+ /// #[pyfunction]
501+ /// fn return_table(...) -> PyResult<PyArrowType<Table>> {
502+ /// let batches: Vec<RecordBatch>;
503+ /// let schema: SchemaRef;
504+ /// PyArrowType(Table::try_new(batches, schema).map_err(|err| err.into_py_err(py))?)
505+ /// }
506+ /// ```
507+ #[ derive( Clone ) ]
508+ pub struct Table {
509+ record_batches : Vec < RecordBatch > ,
510+ schema : SchemaRef ,
511+ }
512+
513+ impl Table {
514+ pub unsafe fn new_unchecked ( record_batches : Vec < RecordBatch > , schema : SchemaRef ) -> Self {
515+ Self {
516+ record_batches,
517+ schema,
518+ }
519+ }
520+
521+ pub fn try_new (
522+ record_batches : Vec < RecordBatch > ,
523+ schema : Option < SchemaRef > ,
524+ ) -> Result < Self , ArrowError > {
525+ let schema = match schema {
526+ Some ( s) => s,
527+ None => {
528+ record_batches
529+ . get ( 0 )
530+ . ok_or_else ( || ArrowError :: SchemaError (
531+ "If no schema is supplied explicitly, there must be at least one RecordBatch!" . to_owned ( )
532+ ) ) ?
533+ . schema ( )
534+ . clone ( )
535+ }
536+ } ;
537+ for record_batch in & record_batches {
538+ if schema != record_batch. schema ( ) {
539+ return Err ( ArrowError :: SchemaError (
540+ "All record batches must have the same schema." . to_owned ( ) ,
541+ ) ) ;
542+ }
543+ }
544+ Ok ( Self {
545+ record_batches,
546+ schema,
547+ } )
548+ }
549+
550+ pub fn record_batches ( & self ) -> & [ RecordBatch ] {
551+ & self . record_batches
552+ }
553+
554+ pub fn schema ( & self ) -> SchemaRef {
555+ self . schema . clone ( )
556+ }
557+
558+ pub fn into_inner ( self ) -> ( Vec < RecordBatch > , SchemaRef ) {
559+ ( self . record_batches , self . schema )
560+ }
561+ }
562+
563+ impl TryFrom < ArrowArrayStreamReader > for Table {
564+ type Error = ArrowError ;
565+
566+ fn try_from ( value : ArrowArrayStreamReader ) -> Result < Self , ArrowError > {
567+ let schema = value. schema ( ) ;
568+ let batches = value. collect :: < Result < Vec < _ > , _ > > ( ) ?;
569+ // We assume all batches have the same schema here.
570+ unsafe { Ok ( Self :: new_unchecked ( batches, schema) ) }
571+ }
572+ }
573+
574+ impl FromPyArrow for Table {
575+ fn from_pyarrow_bound ( ob : & Bound < PyAny > ) -> PyResult < Self > {
576+ let array_stream_reader: PyResult < ArrowArrayStreamReader > = {
577+ // First, try whether the object implements the Arrow ArrayStreamReader protocol directly
578+ // (which `pyarrow.Table` does) or test whether it is a RecordBatchReader.
579+ let reader_result = if let Ok ( reader) = ArrowArrayStreamReader :: from_pyarrow_bound ( ob) {
580+ Some ( reader)
581+ }
582+ // If that is not the case, test whether it has a `to_reader` method (which
583+ // `pyarrow.Table` does) whose return value implements the Arrow ArrayStreamReader
584+ // protocol or is a RecordBatchReader.
585+ else if ob. hasattr ( intern ! ( ob. py( ) , "to_reader" ) ) ? {
586+ let py_reader = ob. getattr ( intern ! ( ob. py( ) , "to_reader" ) ) ?. call0 ( ) ?;
587+ ArrowArrayStreamReader :: from_pyarrow_bound ( & py_reader) . ok ( )
588+ } else {
589+ None
590+ } ;
591+
592+ match reader_result {
593+ Some ( reader) => Ok ( reader) ,
594+ None => Err ( PyTypeError :: new_err (
595+ "Expected Arrow Table, Arrow RecordBatchReader or other object which conforms to the Arrow ArrayStreamReader protocol." ,
596+ ) ) ,
597+ }
598+ } ;
599+ Self :: try_from ( array_stream_reader?)
600+ . map_err ( |err| PyErr :: new :: < PyValueError , _ > ( err. to_string ( ) ) )
601+ }
602+ }
603+
604+ impl IntoPyArrow for Table {
605+ fn into_pyarrow ( self , py : Python ) -> PyResult < Bound < PyAny > > {
606+ let module = py. import ( intern ! ( py, "pyarrow" ) ) ?;
607+ let class = module. getattr ( intern ! ( py, "Table" ) ) ?;
608+
609+ let py_batches = PyList :: new ( py, self . record_batches . into_iter ( ) . map ( PyArrowType ) ) ?;
610+ let py_schema = PyArrowType ( Arc :: unwrap_or_clone ( self . schema ) ) ;
611+
612+ let kwargs = PyDict :: new ( py) ;
613+ kwargs. set_item ( "schema" , py_schema) ?;
614+
615+ let reader = class. call_method ( "from_batches" , ( py_batches, ) , Some ( & kwargs) ) ?;
616+
617+ Ok ( reader)
618+ }
619+ }
620+
487621/// A newtype wrapper for types implementing [`FromPyArrow`] or [`IntoPyArrow`].
488622///
489623/// When wrapped around a type `T: FromPyArrow`, it
0 commit comments