Add PyCapsule support for Arrow import and export #825

Merged · 19 commits · Aug 30, 2024
Changes from 1 commit
Support Arrow PyCapsule for reading arrow tables and for exporting dataframes
timsaucer committed Aug 28, 2024
commit 4a9335644ad2bc531df39df94f21c73ced6ac175
16 changes: 16 additions & 0 deletions python/datafusion/dataframe.py
@@ -524,3 +524,19 @@ def unnest_columns(self, *columns: str, preserve_nulls: bool = True) -> DataFrame:
"""
columns = [c for c in columns]
return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls))

def __arrow_c_stream__(self, requested_schema: pa.Schema) -> Any:
"""Export an Arrow PyCapsule Stream.

This will execute and collect the DataFrame. We will attempt to respect the
requested schema, but only trivial transformations will be applied, such as
returning only the fields listed in the requested schema if their data types
match those in the DataFrame.

Args:
requested_schema: Attempt to provide the DataFrame using this schema.

Returns:
Arrow PyCapsule object.
"""
return self.df.__arrow_c_stream__(requested_schema)
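For context beyond the diff: once this method exists, any Arrow consumer that understands the PyCapsule stream protocol can ingest a DataFusion DataFrame directly. A minimal sketch, assuming pyarrow 14+ (which accepts objects exposing __arrow_c_stream__); the query is illustrative only:

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
df = ctx.sql("SELECT 1 AS a, 2 AS b")

# pyarrow calls df.__arrow_c_stream__() under the hood, so no explicit
# collect() or to-table conversion is needed on the DataFusion side.
table = pa.table(df)
print(table.schema)

pyarrow is only one possible consumer here; any other library that speaks the PyCapsule interface can do the same without going through pyarrow at all.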
38 changes: 29 additions & 9 deletions src/context.rs
@@ -21,11 +21,14 @@ use std::str::FromStr;
use std::sync::Arc;

use datafusion::execution::session_state::SessionStateBuilder;
use arrow::array::RecordBatchReader;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::FromPyArrow;
use object_store::ObjectStore;
use url::Url;
use uuid::Uuid;

use pyo3::exceptions::{PyKeyError, PyValueError};
use pyo3::exceptions::{PyKeyError, PyTypeError, PyValueError};
use pyo3::prelude::*;

use crate::catalog::{PyCatalog, PyTable};
@@ -474,18 +477,35 @@ impl PySessionContext {
name: Option<&str>,
py: Python,
) -> PyResult<PyDataFrame> {
// Instantiate pyarrow Table object & convert to batches
let table = data.call_method0("to_batches")?;
let mut batches = None;
let mut schema = None;

let schema = data.getattr("schema")?;
let schema = schema.extract::<PyArrowType<Schema>>()?;
if let Ok(stream_reader) = ArrowArrayStreamReader::from_pyarrow_bound(&data) {
// Works for any object that implements __arrow_c_stream__ in pycapsule.

schema = Some(stream_reader.schema().as_ref().to_owned());
batches = Some(stream_reader.filter_map(|v| v.ok()).collect());
} else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
// While this says RecordBatch, it will work for any object that implements
// __arrow_c_array__ in pycapsule.

schema = Some(array.schema().as_ref().to_owned());
batches = Some(vec![array]);
}

if batches.is_none() || schema.is_none() {
return Err(PyTypeError::new_err(
"Expected either a Arrow Array or Arrow Stream in from_arrow_table().",
));
}

let batches = batches.unwrap();
let schema = schema.unwrap();

// Cast PyAny to RecordBatch type
// Because create_dataframe() expects a vector of vectors of record batches
// here we need to wrap the vector of record batches in an additional vector
let batches = table.extract::<PyArrowType<Vec<RecordBatch>>>()?;
let list_of_batches = PyArrowType::from(vec![batches.0]);
self.create_dataframe(list_of_batches, name, Some(schema), py)
let list_of_batches = PyArrowType::from(vec![batches]);
self.create_dataframe(list_of_batches, name, Some(schema.into()), py)
}

/// Construct datafusion dataframe from pandas
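A usage sketch for the import side (assumptions: the Python entry point is SessionContext.from_arrow_table with an optional name argument, as the error message above suggests, and the source can be any object exposing __arrow_c_stream__ or __arrow_c_array__):

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# A plain pyarrow Table is used here, but any other PyCapsule-capable object
# (e.g. a polars DataFrame or a nanoarrow stream) would take the same path.
source = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})

# The Rust side first tries ArrowArrayStreamReader::from_pyarrow_bound
# (the __arrow_c_stream__ path) and falls back to RecordBatch::from_pyarrow_bound
# (the __arrow_c_array__ path).
df = ctx.from_arrow_table(source, name="capsule_example")
df.show()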
25 changes: 24 additions & 1 deletion src/dataframe.rs
@@ -15,8 +15,11 @@
// specific language governing permissions and limitations
// under the License.

use std::ffi::CString;
use std::sync::Arc;

use arrow::array::{RecordBatchIterator, RecordBatchReader};
use arrow::ffi_stream::FFI_ArrowArrayStream;
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::arrow::util::pretty;
@@ -29,7 +32,7 @@ use datafusion_common::UnnestOptions;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::PyTuple;
use pyo3::types::{PyCapsule, PyTuple};
use tokio::task::JoinHandle;

use crate::errors::py_datafusion_err;
@@ -451,6 +454,26 @@ impl PyDataFrame {
Ok(table)
}

#[allow(unused_variables)]
fn __arrow_c_stream__<'py>(
&'py mut self,
py: Python<'py>,
requested_schema: Option<Bound<'py, PyCapsule>>,
) -> PyResult<Bound<'py, PyCapsule>> {
let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
.into_iter()
.map(|r| Ok(r));
let schema = self.df.schema().to_owned().into();

// let reader = RecordBatchIterator::new(vec![Ok(self.clone())], self.schema());
let reader = RecordBatchIterator::new(batches, schema);
let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);

let ffi_stream = FFI_ArrowArrayStream::new(reader);
let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
PyCapsule::new_bound(py, ffi_stream, Some(stream_capsule_name))
}

fn execute_stream(&self, py: Python) -> PyResult<PyRecordBatchStream> {
// create a Tokio runtime to run the async code
let rt = &get_tokio_runtime(py).0;
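A small round-trip sketch of the export path implemented above, assuming pyarrow 15+ (where RecordBatchReader.from_stream accepts any object implementing __arrow_c_stream__); the query is illustrative only:

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
df = ctx.sql("SELECT 1 AS id, 'hello' AS msg")

# Calling __arrow_c_stream__ executes and collects the DataFrame, wraps the
# batches in an FFI_ArrowArrayStream, and hands it over as a PyCapsule that
# pyarrow re-imports as a RecordBatchReader.
reader = pa.RecordBatchReader.from_stream(df)
for batch in reader:
    print(batch.num_rows, batch.schema)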