Skip to content

Commit c06664e

Browse files
committed
Backport: Implement equals for stateful functions (apache#16781)
* Implement equals for stateful functions

  The default implementations of `ScalarUDFImpl::equals`, `AggregateUDFImpl::equals` and `WindowUDFImpl::equals` are correct for stateless functions and for those whose only state is the `Signature`, which covers most functions. This change implements `equals` and `hash_value` for functions that have state other than the `Signature` object. This fixes correctness issues which could occur when such stateful functions are used together in one query.
* downgrade for MSRV
* Improve doc
* Update default `UDF::equals` to compare aliases too
* Update default `UDF::equals` to compare type too (‼️)
* remove now-obsolete UDF equals/hash customizations: remove those which compare signature and aliases, as the default handles these now
* remove equals impl which compares name, signature -- default does that
* cleanup imports

(cherry picked from commit afd8235)
1 parent 887f2af commit c06664e

File tree

19 files changed

+580
-23
lines changed

19 files changed

+580
-23
lines changed

datafusion-examples/examples/function_factory.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use datafusion::logical_expr::{
2828
ColumnarValue, CreateFunction, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
2929
Signature, Volatility,
3030
};
31+
use std::hash::{DefaultHasher, Hash, Hasher};
3132
use std::result::Result as RResult;
3233
use std::sync::Arc;
3334

@@ -157,6 +158,38 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
157158
fn output_ordering(&self, _input: &[ExprProperties]) -> Result<SortProperties> {
158159
Ok(SortProperties::Unordered)
159160
}
161+
162+
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
163+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
164+
return false;
165+
};
166+
let Self {
167+
name,
168+
expr,
169+
signature,
170+
return_type,
171+
} = self;
172+
name == &other.name
173+
&& expr == &other.expr
174+
&& signature == &other.signature
175+
&& return_type == &other.return_type
176+
}
177+
178+
fn hash_value(&self) -> u64 {
179+
let Self {
180+
name,
181+
expr,
182+
signature,
183+
return_type,
184+
} = self;
185+
let mut hasher = DefaultHasher::new();
186+
std::any::type_name::<Self>().hash(&mut hasher);
187+
name.hash(&mut hasher);
188+
expr.hash(&mut hasher);
189+
signature.hash(&mut hasher);
190+
return_type.hash(&mut hasher);
191+
hasher.finish()
192+
}
160193
}
161194

162195
impl ScalarFunctionWrapper {

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,34 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF {
216216
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
217217
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100))))
218218
}
219+
220+
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
221+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
222+
return false;
223+
};
224+
let Self {
225+
name,
226+
signature,
227+
return_type,
228+
} = self;
229+
name == &other.name
230+
&& signature == &other.signature
231+
&& return_type == &other.return_type
232+
}
233+
234+
fn hash_value(&self) -> u64 {
235+
let Self {
236+
name,
237+
signature,
238+
return_type,
239+
} = self;
240+
let mut hasher = DefaultHasher::new();
241+
std::any::type_name::<Self>().hash(&mut hasher);
242+
name.hash(&mut hasher);
243+
signature.hash(&mut hasher);
244+
return_type.hash(&mut hasher);
245+
hasher.finish()
246+
}
219247
}
220248

221249
#[tokio::test]
@@ -556,6 +584,34 @@ impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF {
556584
};
557585
Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer))))
558586
}
587+
588+
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
589+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
590+
return false;
591+
};
592+
let Self {
593+
name,
594+
signature,
595+
return_type,
596+
} = self;
597+
name == &other.name
598+
&& signature == &other.signature
599+
&& return_type == &other.return_type
600+
}
601+
602+
fn hash_value(&self) -> u64 {
603+
let Self {
604+
name,
605+
signature,
606+
return_type,
607+
} = self;
608+
let mut hasher = DefaultHasher::new();
609+
std::any::type_name::<Self>().hash(&mut hasher);
610+
name.hash(&mut hasher);
611+
signature.hash(&mut hasher);
612+
return_type.hash(&mut hasher);
613+
hasher.finish()
614+
}
559615
}
560616

561617
#[tokio::test]
@@ -977,6 +1033,38 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
9771033
fn aliases(&self) -> &[String] {
9781034
&[]
9791035
}
1036+
1037+
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
1038+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
1039+
return false;
1040+
};
1041+
let Self {
1042+
name,
1043+
expr,
1044+
signature,
1045+
return_type,
1046+
} = self;
1047+
name == &other.name
1048+
&& expr == &other.expr
1049+
&& signature == &other.signature
1050+
&& return_type == &other.return_type
1051+
}
1052+
1053+
fn hash_value(&self) -> u64 {
1054+
let Self {
1055+
name,
1056+
expr,
1057+
signature,
1058+
return_type,
1059+
} = self;
1060+
let mut hasher = DefaultHasher::new();
1061+
std::any::type_name::<Self>().hash(&mut hasher);
1062+
name.hash(&mut hasher);
1063+
expr.hash(&mut hasher);
1064+
signature.hash(&mut hasher);
1065+
return_type.hash(&mut hasher);
1066+
hasher.finish()
1067+
}
9801068
}
9811069

9821070
impl ScalarFunctionWrapper {

datafusion/doc/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
/// thus all text should be in English.
4040
///
4141
/// [SQL function documentation]: https://datafusion.apache.org/user-guide/sql/index.html
42-
#[derive(Debug, Clone)]
42+
#[derive(Debug, Clone, PartialEq, Hash)]
4343
pub struct Documentation {
4444
/// The section in the documentation where the UDF will be documented
4545
pub doc_section: DocSection,
@@ -158,7 +158,7 @@ impl Documentation {
158158
}
159159
}
160160

161-
#[derive(Debug, Clone, PartialEq)]
161+
#[derive(Debug, Clone, PartialEq, Hash)]
162162
pub struct DocSection {
163163
/// True to include this doc section in the public
164164
/// documentation, false otherwise

datafusion/expr/src/expr_fn.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
4444
use sqlparser::ast::NullTreatment;
4545
use std::any::Any;
4646
use std::fmt::Debug;
47+
use std::hash::{DefaultHasher, Hash, Hasher};
4748
use std::ops::Not;
4849
use std::sync::Arc;
4950

@@ -474,6 +475,38 @@ impl ScalarUDFImpl for SimpleScalarUDF {
474475
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
475476
(self.fun)(&args.args)
476477
}
478+
479+
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
480+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
481+
return false;
482+
};
483+
let Self {
484+
name,
485+
signature,
486+
return_type,
487+
fun,
488+
} = self;
489+
name == &other.name
490+
&& signature == &other.signature
491+
&& return_type == &other.return_type
492+
&& Arc::ptr_eq(fun, &other.fun)
493+
}
494+
495+
fn hash_value(&self) -> u64 {
496+
let Self {
497+
name,
498+
signature,
499+
return_type,
500+
fun,
501+
} = self;
502+
let mut hasher = DefaultHasher::new();
503+
std::any::type_name::<Self>().hash(&mut hasher);
504+
name.hash(&mut hasher);
505+
signature.hash(&mut hasher);
506+
return_type.hash(&mut hasher);
507+
Arc::as_ptr(fun).hash(&mut hasher);
508+
hasher.finish()
509+
}
477510
}
478511

479512
/// Creates a new UDAF with a specific signature, state type and return type.
@@ -594,6 +627,42 @@ impl AggregateUDFImpl for SimpleAggregateUDF {
594627
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
595628
Ok(self.state_fields.clone())
596629
}
630+
631+
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
632+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
633+
return false;
634+
};
635+
let Self {
636+
name,
637+
signature,
638+
return_type,
639+
accumulator,
640+
state_fields,
641+
} = self;
642+
name == &other.name
643+
&& signature == &other.signature
644+
&& return_type == &other.return_type
645+
&& Arc::ptr_eq(accumulator, &other.accumulator)
646+
&& state_fields == &other.state_fields
647+
}
648+
649+
fn hash_value(&self) -> u64 {
650+
let Self {
651+
name,
652+
signature,
653+
return_type,
654+
accumulator,
655+
state_fields,
656+
} = self;
657+
let mut hasher = DefaultHasher::new();
658+
std::any::type_name::<Self>().hash(&mut hasher);
659+
name.hash(&mut hasher);
660+
signature.hash(&mut hasher);
661+
return_type.hash(&mut hasher);
662+
Arc::as_ptr(accumulator).hash(&mut hasher);
663+
state_fields.hash(&mut hasher);
664+
hasher.finish()
665+
}
597666
}
598667

599668
/// Creates a new UDWF with a specific signature, state type and return type.
@@ -686,6 +755,41 @@ impl WindowUDFImpl for SimpleWindowUDF {
686755
true,
687756
)))
688757
}
758+
759+
fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
760+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
761+
return false;
762+
};
763+
let Self {
764+
name,
765+
signature,
766+
return_type,
767+
partition_evaluator_factory,
768+
} = self;
769+
name == &other.name
770+
&& signature == &other.signature
771+
&& return_type == &other.return_type
772+
&& Arc::ptr_eq(
773+
partition_evaluator_factory,
774+
&other.partition_evaluator_factory,
775+
)
776+
}
777+
778+
fn hash_value(&self) -> u64 {
779+
let Self {
780+
name,
781+
signature,
782+
return_type,
783+
partition_evaluator_factory,
784+
} = self;
785+
let mut hasher = DefaultHasher::new();
786+
std::any::type_name::<Self>().hash(&mut hasher);
787+
name.hash(&mut hasher);
788+
signature.hash(&mut hasher);
789+
return_type.hash(&mut hasher);
790+
Arc::as_ptr(partition_evaluator_factory).hash(&mut hasher);
791+
hasher.finish()
792+
}
689793
}
690794

691795
pub fn interval_year_month_lit(value: &str) -> Expr {

datafusion/expr/src/udaf.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -898,26 +898,35 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
898898
/// Return true if this aggregate UDF is equal to the other.
899899
///
900900
/// Allows customizing the equality of aggregate UDFs.
901+
/// *Must* be implemented explicitly if the UDF type has internal state.
901902
/// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
902903
///
903904
/// - reflexive: `a.equals(a)`;
904905
/// - symmetric: `a.equals(b)` implies `b.equals(a)`;
905906
/// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
906907
///
907-
/// By default, compares [`Self::name`] and [`Self::signature`].
908+
/// By default, compares type, [`Self::name`], [`Self::aliases`] and [`Self::signature`].
908909
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
909-
self.name() == other.name() && self.signature() == other.signature()
910+
self.as_any().type_id() == other.as_any().type_id()
911+
&& self.name() == other.name()
912+
&& self.aliases() == other.aliases()
913+
&& self.signature() == other.signature()
910914
}
911915

912916
/// Returns a hash value for this aggregate UDF.
913917
///
914-
/// Allows customizing the hash code of aggregate UDFs. Similarly to [`Hash`] and [`Eq`],
915-
/// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same.
918+
/// Allows customizing the hash code of aggregate UDFs.
919+
/// *Must* be implemented explicitly whenever [`Self::equals`] is implemented.
916920
///
917-
/// By default, hashes [`Self::name`] and [`Self::signature`].
921+
/// Similarly to [`Hash`] and [`Eq`], if [`Self::equals`] returns true for two UDFs,
922+
/// their `hash_value`s must be the same.
923+
///
924+
/// By default, it is consistent with the default implementation of [`Self::equals`].
918925
fn hash_value(&self) -> u64 {
919926
let hasher = &mut DefaultHasher::new();
927+
self.as_any().type_id().hash(hasher);
920928
self.name().hash(hasher);
929+
self.aliases().hash(hasher);
921930
self.signature().hash(hasher);
922931
hasher.finish()
923932
}

datafusion/expr/src/udf.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -690,26 +690,35 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
690690
/// Return true if this scalar UDF is equal to the other.
691691
///
692692
/// Allows customizing the equality of scalar UDFs.
693+
/// *Must* be implemented explicitly if the UDF type has internal state.
693694
/// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
694695
///
695696
/// - reflexive: `a.equals(a)`;
696697
/// - symmetric: `a.equals(b)` implies `b.equals(a)`;
697698
/// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
698699
///
699-
/// By default, compares [`Self::name`] and [`Self::signature`].
700+
/// By default, compares type, [`Self::name`], [`Self::aliases`] and [`Self::signature`].
700701
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
701-
self.name() == other.name() && self.signature() == other.signature()
702+
self.as_any().type_id() == other.as_any().type_id()
703+
&& self.name() == other.name()
704+
&& self.aliases() == other.aliases()
705+
&& self.signature() == other.signature()
702706
}
703707

704708
/// Returns a hash value for this scalar UDF.
705709
///
706-
/// Allows customizing the hash code of scalar UDFs. Similarly to [`Hash`] and [`Eq`],
707-
/// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same.
710+
/// Allows customizing the hash code of scalar UDFs.
711+
/// *Must* be implemented explicitly whenever [`Self::equals`] is implemented.
708712
///
709-
/// By default, hashes [`Self::name`] and [`Self::signature`].
713+
/// Similarly to [`Hash`] and [`Eq`], if [`Self::equals`] returns true for two UDFs,
714+
/// their `hash_value`s must be the same.
715+
///
716+
/// By default, it is consistent with the default implementation of [`Self::equals`].
710717
fn hash_value(&self) -> u64 {
711718
let hasher = &mut DefaultHasher::new();
719+
self.as_any().type_id().hash(hasher);
712720
self.name().hash(hasher);
721+
self.aliases().hash(hasher);
713722
self.signature().hash(hasher);
714723
hasher.finish()
715724
}
@@ -825,6 +834,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
825834

826835
fn hash_value(&self) -> u64 {
827836
let hasher = &mut DefaultHasher::new();
837+
std::any::type_name::<Self>().hash(hasher);
828838
self.inner.hash_value().hash(hasher);
829839
self.aliases.hash(hasher);
830840
hasher.finish()

0 commit comments

Comments
 (0)