Skip to content

Commit 032ff40

Browse files
committed
POC for easier column extraction
1 parent f08d5b0 commit 032ff40

File tree

3 files changed

+63
-0
lines changed

3 files changed

+63
-0
lines changed

python/datafusion/dataframe.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,6 +1068,26 @@ def to_arrow_table(self) -> pa.Table:
10681068
"""
10691069
return self.df.to_arrow_table()
10701070

1071+
def to_arrow_array(self) -> pa.ChunkedArray:
1072+
"""Execute the :py:class:`DataFrame` and convert it into an Arrow Array.
1073+
1074+
Only valid when :py:class:`DataFrame` contains a single column.
1075+
1076+
Returns:
1077+
Arrow Array.
1078+
"""
1079+
return self.df.to_arrow_array()
1080+
1081+
def column(self, column_name: str) -> pa.ChunkedArray:
1082+
"""Execute the :py:class:`DataFrame` and convert it into an Arrow Array.
1083+
1084+
Only valid when :py:class:`DataFrame` contains a single column.
1085+
1086+
Returns:
1087+
Arrow Array.
1088+
"""
1089+
return self.df.column(column_name)
1090+
10711091
def execute_stream(self) -> RecordBatchStream:
10721092
"""Executes this DataFrame and returns a stream over a single partition.
10731093

python/tests/test_dataframe.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,6 +1694,23 @@ def test_to_arrow_table(df):
16941694
assert set(pyarrow_table.column_names) == {"a", "b", "c"}
16951695

16961696

1697+
def test_to_arrow_array(df):
1698+
# Convert datafusion dataframe to pyarrow Array
1699+
pyarrow_array = df.select("a").to_arrow_array()
1700+
assert isinstance(pyarrow_array, pa.ChunkedArray)
1701+
assert pyarrow_array.to_numpy().shape == (3,)
1702+
1703+
with pytest.raises(ValueError, match="single column"):
1704+
df.to_arrow_array()
1705+
1706+
1707+
def test_column(df):
1708+
# Grab column from datafusion dataframe as pyarrow Array
1709+
pyarrow_array = df.column("a")
1710+
assert isinstance(pyarrow_array, pa.ChunkedArray)
1711+
assert pyarrow_array.to_numpy().shape == (3,)
1712+
1713+
16971714
def test_execute_stream(df):
16981715
stream = df.execute_stream()
16991716
assert all(batch is not None for batch in stream)

src/dataframe.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,32 @@ impl PyDataFrame {
859859
Ok(())
860860
}
861861

862+
/// Extract a single column directly as an Arrow Array
863+
fn column(&self, py: Python<'_>, column_name: &str) -> PyResult<PyObject> {
864+
let single_column_df = self.select(vec![PyExpr::column(column_name)])?;
865+
let array = single_column_df.to_arrow_array(py)?;
866+
Ok(array)
867+
}
868+
869+
/// Convert to Arrow Array
870+
/// Collect the batches and pass to Arrow Array
871+
fn to_arrow_array(&self, py: Python<'_>) -> PyResult<PyObject> {
872+
println!("Converting to Arrow table");
873+
let table = self.to_arrow_table(py)?;
874+
println!("Table");
875+
let args = table.getattr(py, "column_names")?;
876+
let column_names = args.extract::<Bound<PyList>>(py)?;
877+
if column_names.len() != 1 {
878+
return Err(PyValueError::new_err(
879+
"to_arrow_array only supports single column DataFrames",
880+
));
881+
}
882+
print!("Args");
883+
let column_name = column_names.get_item(0)?;
884+
let array: PyObject = table.call_method1(py, "column", (column_name,))?;
885+
Ok(array)
886+
}
887+
862888
/// Convert to Arrow Table
863889
/// Collect the batches and pass to Arrow Table
864890
fn to_arrow_table(&self, py: Python<'_>) -> PyResult<PyObject> {

0 commit comments

Comments
 (0)