Skip to content

Commit

Permalink
Extends PhysicalExtensionCodec for udwf
Browse files Browse the repository at this point in the history
  • Loading branch information
jcsherin committed Nov 14, 2024
1 parent 30538fb commit 46a00e8
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 5 deletions.
60 changes: 57 additions & 3 deletions datafusion/proto/tests/cases/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::{DataType, Field};
use std::any::Any;

use arrow::datatypes::DataType;
use std::fmt::Debug;

use datafusion_common::plan_err;
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, Signature, Volatility,
Accumulator, AggregateUDFImpl, ColumnarValue, PartitionEvaluator, ScalarUDFImpl,
Signature, Volatility, WindowUDFImpl,
};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;

mod roundtrip_logical_plan;
mod roundtrip_physical_plan;
Expand Down Expand Up @@ -125,3 +128,54 @@ pub struct MyAggregateUdfNode {
#[prost(string, tag = "1")]
pub result: String,
}

/// User-defined window function fixture used by the proto roundtrip tests.
///
/// Visibility is restricted to the `crate::cases` module tree.
#[derive(Debug)]
pub(in crate::cases) struct CustomUDWF {
/// Signature handed out by `WindowUDFImpl::signature`.
signature: Signature,
/// Caller-supplied string; the extension codec copies it into
/// `CustomUDWFNode::payload` when encoding and reads it back when decoding.
payload: String,
}

impl CustomUDWF {
    /// Builds a `CustomUDWF` that carries `payload`.
    ///
    /// The signature is fixed: exactly one `Int64` argument with
    /// `Volatility::Immutable`.
    pub fn new(payload: String) -> Self {
        let arg_types = vec![DataType::Int64];
        let signature = Signature::exact(arg_types, Volatility::Immutable);
        Self { signature, payload }
    }
}

impl WindowUDFImpl for CustomUDWF {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"custom_udwf"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn partition_evaluator(
&self,
_partition_evaluator_args: PartitionEvaluatorArgs,
) -> datafusion_common::Result<Box<dyn PartitionEvaluator>> {
Ok(Box::new(CustomUDWFEvaluator {}))
}

fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result<Field> {
Ok(Field::new(field_args.name(), DataType::UInt64, false))
}
}

/// Evaluator type returned by `CustomUDWF::partition_evaluator`.
///
/// It holds no state.
#[derive(Debug)]
struct CustomUDWFEvaluator;

// Empty impl: no trait method is overridden here, so every method falls back
// to whatever `PartitionEvaluator` provides by default.
impl PartitionEvaluator for CustomUDWFEvaluator {}

/// Protobuf message (derived via `prost::Message`) used to serialize a
/// `CustomUDWF` in the extension-codec roundtrip test.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct CustomUDWFNode {
/// Payload string of the window function; protobuf field tag 1.
#[prost(string, tag = "1")]
pub payload: String,
}
83 changes: 81 additions & 2 deletions datafusion/proto/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ use datafusion_functions_aggregate::array_agg::array_agg_udaf;
use datafusion_functions_aggregate::min_max::max_udaf;
use prost::Message;

use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, MyRegexUdfNode};
use crate::cases::{
CustomUDWF, CustomUDWFNode, MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf,
MyRegexUdfNode,
};
use datafusion::arrow::array::ArrayRef;
use datafusion::arrow::compute::kernels::sort::SortOptions;
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema};
Expand Down Expand Up @@ -94,7 +97,7 @@ use datafusion_common::{
};
use datafusion_expr::{
Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF,
Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound,
Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, WindowUDF,
};
use datafusion_functions_aggregate::average::avg_udaf;
use datafusion_functions_aggregate::nth_value::nth_value_udaf;
Expand Down Expand Up @@ -1016,6 +1019,33 @@ impl PhysicalExtensionCodec for UDFExtensionCodec {
}
Ok(())
}

/// Rebuilds a window UDF from its serialized form.
///
/// Only the name `"custom_udwf"` is recognised: `buf` is decoded as a
/// `CustomUDWFNode` and its payload is used to construct a new `CustomUDWF`.
///
/// # Errors
/// Returns a not-implemented error for any other name, and an internal
/// error when `buf` cannot be decoded as a `CustomUDWFNode`.
fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result<Arc<WindowUDF>> {
    // Guard: reject every function this codec does not know about.
    if name != "custom_udwf" {
        return not_impl_err!(
            "unrecognized user-defined window function implementation, cannot decode"
        );
    }

    let node = CustomUDWFNode::decode(buf).map_err(|err| {
        DataFusionError::Internal(format!("failed to decode custom_udwf: {err}"))
    })?;
    let restored = CustomUDWF::new(node.payload);

    Ok(Arc::new(WindowUDF::from(restored)))
}

/// Serializes a window UDF into `buf`.
///
/// When the wrapped implementation is a `CustomUDWF`, its payload is written
/// as a `CustomUDWFNode` message. For any other implementation nothing is
/// written and `Ok(())` is returned.
///
/// # Errors
/// Returns an internal error if encoding the message fails.
fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec<u8>) -> Result<()> {
    let inner = node.inner();
    match inner.as_any().downcast_ref::<CustomUDWF>() {
        Some(custom) => {
            let message = CustomUDWFNode {
                payload: custom.payload.clone(),
            };
            message.encode(buf).map_err(|err| {
                DataFusionError::Internal(format!("failed to encode udwf: {err:?}"))
            })
        }
        // Not our type: leave the buffer untouched.
        None => Ok(()),
    }
}
}

#[test]
Expand Down Expand Up @@ -1073,6 +1103,55 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
Ok(())
}

/// Roundtrips a `BoundedWindowAggExec` containing `CustomUDWF` through
/// `UDFExtensionCodec`.
#[test]
fn roundtrip_udwf_extension_codec() -> Result<()> {
    // Input schema: two non-nullable Int64 columns, `a` and `b`.
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Int64, false),
    ]));

    // Physical window expression for `custom_udwf(a)`.
    let custom_udwf = Arc::new(WindowUDF::from(CustomUDWF::new("payload".to_string())));
    let expr_name =
        "custom_udwf(a) PARTITION BY [b] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
            .to_string();
    let udwf_args = [col("a", &schema)?];
    let udwf = create_udwf_window_expr(
        &custom_udwf,
        &udwf_args,
        schema.as_ref(),
        expr_name,
        false,
    )?;

    // RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW.
    let frame = Arc::new(WindowFrame::new_bounds(
        datafusion_expr::WindowFrameUnits::Range,
        WindowFrameBound::Preceding(ScalarValue::Int64(None)),
        WindowFrameBound::CurrentRow,
    ));

    // Partition by `b`, order by `a` ascending with nulls last.
    let partition_by = [col("b", &schema)?];
    let order_by = LexOrdering {
        inner: vec![PhysicalSortExpr {
            expr: col("a", &schema)?,
            options: SortOptions {
                descending: false,
                nulls_first: false,
            },
        }],
    };
    let udwf_expr = Arc::new(BuiltInWindowExpr::new(
        udwf,
        &partition_by,
        &order_by,
        frame,
    ));

    // Window operator over an empty input with the same schema.
    let input = Arc::new(EmptyExec::new(schema.clone()));
    let partition_keys = vec![col("b", &schema)?];
    let window = Arc::new(BoundedWindowAggExec::try_new(
        vec![udwf_expr],
        input,
        partition_keys,
        InputOrderMode::Sorted,
    )?);

    let ctx = SessionContext::new();
    let _roundtripped = roundtrip_test_and_return(window, &ctx, &UDFExtensionCodec)?;
    Ok(())
}

#[test]
fn roundtrip_aggregate_udf_extension_codec() -> Result<()> {
let field_text = Field::new("text", DataType::Utf8, true);
Expand Down

0 comments on commit 46a00e8

Please sign in to comment.