
feat: Add Aggregate UDF to FFI crate #14775


Merged: 33 commits merged into main from feat/aggregate-udf-ffi on Jun 5, 2025
Changes from all commits (33 commits)
83bce67
Work in progress adding user defined aggregate function FFI support
timsaucer Feb 15, 2025
efa2e5c
Intermediate work. Going through groups accumulator
timsaucer Feb 18, 2025
53fafc9
MVP for aggregate udf via FFI
timsaucer Feb 19, 2025
95c3c79
Clean up after rebase
timsaucer Feb 20, 2025
a91ee5b
Add unit test for FFI Accumulator Args
timsaucer Feb 22, 2025
d15b3a1
Adding unit tests and fixing memory errors in aggregate ffi udf
timsaucer Feb 23, 2025
eb6a072
Working through additional unit and integration tests for UDAF ffi
timsaucer Feb 23, 2025
9d31d1f
Switch to an accumulator that supports convert to state to get a littl…
timsaucer Feb 23, 2025
217bb8e
Set feature so we do not get an error warning in stable rustc
timsaucer Feb 23, 2025
b5b11d4
Add more options to test
timsaucer Feb 24, 2025
ae61d88
Add unit test for FFI RecordBatchStream
timsaucer Feb 25, 2025
dfd3268
Add a few more args to ffi accumulator test fn
timsaucer Feb 25, 2025
aa4b7ce
Adding more unit tests on ffi aggregate udaf
timsaucer Feb 27, 2025
2de6d0a
taplo format
timsaucer Feb 27, 2025
11f88de
Update code comment
timsaucer Feb 27, 2025
9300ba5
Correct function name
timsaucer Feb 27, 2025
ec05091
Temp fix record batch test dependencies
crystalxyz Apr 1, 2025
fc40bc0
Address some comments
crystalxyz Apr 2, 2025
4d164cc
Revise comments and address PR comments
crystalxyz Apr 8, 2025
173d7c0
Remove commented code
crystalxyz Apr 8, 2025
45ea283
Refactor GroupsAccumulator
crystalxyz Apr 8, 2025
b6da0e9
Add documentation
crystalxyz Apr 12, 2025
a2fbda8
Split integration tests
crystalxyz Apr 12, 2025
0b4a8f5
Address comments to refactor error handling for opt filter
crystalxyz Apr 12, 2025
4b3c533
Fix linting errors
crystalxyz Apr 12, 2025
1b85dd9
Fix linting and add deref
crystalxyz Apr 12, 2025
8a4de4a
Remove extra tests and unnecessary code
crystalxyz Jun 3, 2025
5e02c72
Adjustments to FFI aggregate functions after rebase on main
timsaucer Jun 4, 2025
d128b85
cargo fmt
timsaucer Jun 4, 2025
4282c2a
cargo clippy
timsaucer Jun 4, 2025
bb84d08
Re-implement cleaned up code that was removed in last push
timsaucer Jun 4, 2025
1c3fad9
Minor review comments
timsaucer Jun 4, 2025
e63af82
Merge branch 'main' into feat/aggregate-udf-ffi
alamb Jun 5, 2025
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -216,5 +216,5 @@ unnecessary_lazy_evaluations = "warn"
uninlined_format_args = "warn"

[workspace.lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ["cfg(tarpaulin)"] }
unexpected_cfgs = { level = "warn", check-cfg = ["cfg(tarpaulin)", "cfg(tarpaulin_include)"] }
unused_qualifications = "deny"
49 changes: 5 additions & 44 deletions datafusion/core/src/datasource/file_format/parquet.rs
@@ -107,10 +107,8 @@ pub(crate) mod test_util {
mod tests {

use std::fmt::{self, Display, Formatter};
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use crate::datasource::file_format::parquet::test_util::store_parquet;
@@ -120,7 +118,7 @@ mod tests {
use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext};

use arrow::array::RecordBatch;
use arrow_schema::{Schema, SchemaRef};
use arrow_schema::Schema;
use datafusion_catalog::Session;
use datafusion_common::cast::{
as_binary_array, as_binary_view_array, as_boolean_array, as_float32_array,
@@ -140,7 +138,7 @@
};
use datafusion_execution::object_store::ObjectStoreUrl;
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_execution::{RecordBatchStream, TaskContext};
use datafusion_execution::TaskContext;
use datafusion_expr::dml::InsertOp;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{collect, ExecutionPlan};
@@ -153,7 +151,7 @@
use async_trait::async_trait;
use datafusion_datasource::file_groups::FileGroup;
use futures::stream::BoxStream;
use futures::{Stream, StreamExt};
use futures::StreamExt;
use insta::assert_snapshot;
use log::error;
use object_store::local::LocalFileSystem;
@@ -169,6 +167,8 @@
use parquet::format::FileMetaData;
use tokio::fs::File;

use crate::test_util::bounded_stream;

enum ForceViews {
Yes,
No,
@@ -1662,43 +1662,4 @@

Ok(())
}

/// Creates an bounded stream for testing purposes.
fn bounded_stream(
batch: RecordBatch,
limit: usize,
) -> datafusion_execution::SendableRecordBatchStream {
Box::pin(BoundedStream {
count: 0,
limit,
batch,
})
}

struct BoundedStream {
limit: usize,
count: usize,
batch: RecordBatch,
}

impl Stream for BoundedStream {
type Item = Result<RecordBatch>;

fn poll_next(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
if self.count >= self.limit {
return Poll::Ready(None);
}
self.count += 1;
Poll::Ready(Some(Ok(self.batch.clone())))
}
}

impl RecordBatchStream for BoundedStream {
fn schema(&self) -> SchemaRef {
self.batch.schema()
}
}
}
47 changes: 47 additions & 0 deletions datafusion/core/src/test_util/mod.rs
@@ -22,12 +22,14 @@ pub mod parquet;

pub mod csv;

use futures::Stream;
use std::any::Any;
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::sync::Arc;
use std::task::{Context, Poll};

use crate::catalog::{TableProvider, TableProviderFactory};
use crate::dataframe::DataFrame;
@@ -38,11 +40,13 @@ use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE};
use crate::physical_plan::ExecutionPlan;
use crate::prelude::{CsvReadOptions, SessionContext};

use crate::execution::SendableRecordBatchStream;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_catalog::Session;
use datafusion_common::TableReference;
use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType};
use std::pin::Pin;

use async_trait::async_trait;

@@ -52,6 +56,8 @@ use tempfile::TempDir;
pub use datafusion_common::test_util::parquet_test_data;
pub use datafusion_common::test_util::{arrow_test_data, get_data_dir};

use crate::execution::RecordBatchStream;

/// Scan an empty data source, mainly used in tests
pub fn scan_empty(
name: Option<&str>,
@@ -234,3 +240,44 @@ pub fn register_unbounded_file_with_ordering(
ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?;
Ok(())
}

/// Creates a bounded stream that emits the same record batch a specified number of times.
/// This is useful for testing purposes.
pub fn bounded_stream(
record_batch: RecordBatch,
limit: usize,
) -> SendableRecordBatchStream {
Box::pin(BoundedStream {
record_batch,
count: 0,
limit,
})
}

struct BoundedStream {
record_batch: RecordBatch,
count: usize,
limit: usize,
}

impl Stream for BoundedStream {
type Item = Result<RecordBatch, crate::error::DataFusionError>;

fn poll_next(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
if self.count >= self.limit {
Poll::Ready(None)
} else {
self.count += 1;
Poll::Ready(Some(Ok(self.record_batch.clone())))
}
}
}

impl RecordBatchStream for BoundedStream {
fn schema(&self) -> SchemaRef {
self.record_batch.schema()
}
}
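
Note (not part of the diff): the helper above is exported from datafusion::test_util and is used by the FFI record batch stream test later in this PR. A minimal usage sketch; the test name and collection loop are illustrative:

use datafusion::common::record_batch;
use datafusion::error::Result;
use datafusion::test_util::bounded_stream;
use futures::StreamExt;

#[tokio::test]
async fn emits_the_batch_a_fixed_number_of_times() -> Result<()> {
    let batch = record_batch!(("a", Int32, vec![1, 2, 3]))?;
    // Stream that yields clones of `batch` exactly 5 times, then ends.
    let mut stream = bounded_stream(batch.clone(), 5);

    let mut count = 0;
    while let Some(next) = stream.next().await {
        assert_eq!(next?, batch);
        count += 1;
    }
    assert_eq!(count, 5);
    Ok(())
}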
2 changes: 1 addition & 1 deletion datafusion/expr-common/src/groups_accumulator.rs
@@ -21,7 +21,7 @@ use arrow::array::{ArrayRef, BooleanArray};
use datafusion_common::{not_impl_err, Result};

/// Describes how many rows should be emitted during grouping.
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmitTo {
/// Emit all groups
All,
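
Note (not part of the diff): deriving PartialEq and Eq lets callers compare EmitTo values directly, for example in round-trip assertions. A small sketch, assuming the usual re-export of EmitTo through datafusion_expr:

use datafusion_expr::EmitTo;

// Possible only now that EmitTo derives PartialEq/Eq.
fn assert_same_emit(original: EmitTo, round_tripped: EmitTo) {
    assert_eq!(original, round_tripped);
}

fn main() {
    assert_same_emit(EmitTo::All, EmitTo::All);
    assert_same_emit(EmitTo::First(10), EmitTo::First(10));
}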
3 changes: 3 additions & 0 deletions datafusion/ffi/Cargo.toml
@@ -44,7 +44,9 @@ arrow-schema = { workspace = true }
async-ffi = { version = "0.5.0", features = ["abi_stable"] }
async-trait = { workspace = true }
datafusion = { workspace = true, default-features = false }
datafusion-functions-aggregate-common = { workspace = true }
datafusion-proto = { workspace = true }
datafusion-proto-common = { workspace = true }
futures = { workspace = true }
log = { workspace = true }
prost = { workspace = true }
@@ -56,3 +58,4 @@ doc-comment = { workspace = true }

[features]
integration-tests = []
tarpaulin_include = [] # Exists only to prevent warnings on stable and still have accurate coverage
33 changes: 24 additions & 9 deletions datafusion/ffi/src/arrow_wrappers.rs
@@ -21,7 +21,8 @@ use abi_stable::StableAbi;
use arrow::{
array::{make_array, ArrayRef},
datatypes::{Schema, SchemaRef},
ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema},
error::ArrowError,
ffi::{from_ffi, to_ffi, FFI_ArrowArray, FFI_ArrowSchema},
};
use log::error;

@@ -44,16 +45,19 @@ impl From<SchemaRef> for WrappedSchema {
WrappedSchema(ffi_schema)
}
}
/// Some functions are expected to always succeed, like getting the schema from a TableProvider.
/// Since going through the FFI always has the potential to fail, we need to catch these errors,
/// give the user a warning, and return some kind of result. In this case we default to an
/// empty schema.
#[cfg(not(tarpaulin_include))]
fn catch_df_schema_error(e: ArrowError) -> Schema {
error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {e}");
Schema::empty()
}

impl From<WrappedSchema> for SchemaRef {
fn from(value: WrappedSchema) -> Self {
let schema = match Schema::try_from(&value.0) {
Ok(s) => s,
Err(e) => {
error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {e}");
Schema::empty()
}
};
let schema = Schema::try_from(&value.0).unwrap_or_else(catch_df_schema_error);
Arc::new(schema)
}
}
@@ -71,11 +75,22 @@ pub struct WrappedArray {
}

impl TryFrom<WrappedArray> for ArrayRef {
type Error = arrow::error::ArrowError;
type Error = ArrowError;

fn try_from(value: WrappedArray) -> Result<Self, Self::Error> {
let data = unsafe { from_ffi(value.array, &value.schema.0)? };

Ok(make_array(data))
}
}

impl TryFrom<&ArrayRef> for WrappedArray {
type Error = ArrowError;

fn try_from(array: &ArrayRef) -> Result<Self, Self::Error> {
let (array, schema) = to_ffi(&array.to_data())?;
let schema = WrappedSchema(schema);

Ok(WrappedArray { array, schema })
}
}
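
Note (not part of the diff): with the new TryFrom<&ArrayRef> impl, WrappedArray now converts in both directions. A minimal round-trip sketch, assuming the arrow_wrappers module is public in datafusion_ffi:

use std::sync::Arc;
use arrow::array::{ArrayRef, Int32Array};
use arrow::error::ArrowError;
use datafusion_ffi::arrow_wrappers::WrappedArray;

fn round_trip_array() -> Result<(), ArrowError> {
    let original: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));

    // Native -> FFI-safe wrapper (uses arrow's to_ffi internally).
    let wrapped = WrappedArray::try_from(&original)?;

    // FFI-safe wrapper -> native array (uses arrow's from_ffi internally).
    let recovered = ArrayRef::try_from(wrapped)?;

    assert_eq!(original.to_data(), recovered.to_data());
    Ok(())
}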
1 change: 1 addition & 0 deletions datafusion/ffi/src/lib.rs
@@ -34,6 +34,7 @@ pub mod schema_provider;
pub mod session_config;
pub mod table_provider;
pub mod table_source;
pub mod udaf;
pub mod udf;
pub mod udtf;
pub mod util;
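
Note (not part of the diff): the new udaf module is the core of this PR. A hypothetical round-trip sketch, assuming it mirrors the existing udf module's pattern with an FFI_AggregateUDF wrapper on the provider side and a ForeignAggregateUDF on the consuming side; these names and signatures are illustrative, not confirmed by this hunk:

use std::sync::Arc;
use datafusion::error::Result;
use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::logical_expr::AggregateUDF;
use datafusion_ffi::udaf::{FFI_AggregateUDF, ForeignAggregateUDF};

fn round_trip_aggregate_udf() -> Result<()> {
    // Provider side: wrap a native aggregate UDF so it can cross the FFI boundary.
    let original: Arc<AggregateUDF> = sum_udaf();
    let ffi_udaf: FFI_AggregateUDF = Arc::clone(&original).into();

    // Consumer side: rebuild a usable AggregateUDF from the FFI wrapper.
    let foreign = ForeignAggregateUDF::try_from(&ffi_udaf)?;
    let round_tripped = AggregateUDF::from(foreign);

    assert_eq!(original.name(), round_tripped.name());
    Ok(())
}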
14 changes: 11 additions & 3 deletions datafusion/ffi/src/plan_properties.rs
@@ -300,7 +300,10 @@ impl From<FFI_EmissionType> for EmissionType {

#[cfg(test)]
mod tests {
use datafusion::physical_plan::Partitioning;
use datafusion::{
physical_expr::{LexOrdering, PhysicalSortExpr},
physical_plan::Partitioning,
};

use super::*;

@@ -311,8 +314,13 @@
Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]));

let original_props = PlanProperties::new(
EquivalenceProperties::new(schema),
Partitioning::UnknownPartitioning(3),
EquivalenceProperties::new(Arc::clone(&schema)).with_reorder(
LexOrdering::new(vec![PhysicalSortExpr {
expr: datafusion::physical_plan::expressions::col("a", &schema)?,
options: Default::default(),
}]),
),
Partitioning::RoundRobinBatch(3),
EmissionType::Incremental,
Boundedness::Bounded,
);
46 changes: 46 additions & 0 deletions datafusion/ffi/src/record_batch_stream.rs
@@ -196,3 +196,49 @@ impl Stream for FFI_RecordBatchStream {
}
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::{
common::record_batch, error::Result, execution::SendableRecordBatchStream,
test_util::bounded_stream,
};

use super::FFI_RecordBatchStream;
use futures::StreamExt;

#[tokio::test]
async fn test_round_trip_record_batch_stream() -> Result<()> {
let record_batch = record_batch!(
("a", Int32, vec![1, 2, 3]),
("b", Float64, vec![Some(4.0), None, Some(5.0)])
)?;
let original_rbs = bounded_stream(record_batch.clone(), 1);

let ffi_rbs: FFI_RecordBatchStream = original_rbs.into();
let mut ffi_rbs: SendableRecordBatchStream = Box::pin(ffi_rbs);

let schema = ffi_rbs.schema();
assert_eq!(
schema,
Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Float64, true)
]))
);

let batch = ffi_rbs.next().await;
assert!(batch.is_some());
assert!(batch.as_ref().unwrap().is_ok());
assert_eq!(batch.unwrap().unwrap(), record_batch);

// There should only be one batch
let no_batch = ffi_rbs.next().await;
assert!(no_batch.is_none());

Ok(())
}
}