Skip to content

Commit 48492c2

Browse files
authored
refactor(rust): Support for named/anonymous aggregations (#25118)
1 parent 3011cad commit 48492c2

File tree

25 files changed

+254
-14
lines changed

25 files changed

+254
-14
lines changed

crates/polars-arrow/src/array/fixed_size_list/mutable.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,15 @@ impl<M: MutableArray> MutableFixedSizeListArray<M> {
149149
validity.shrink_to_fit()
150150
}
151151
}
152+
153+
pub fn freeze(mut self) -> FixedSizeListArray {
154+
FixedSizeListArray::new(
155+
self.dtype,
156+
self.length,
157+
self.values.as_box(),
158+
self.validity.map(|b| b.freeze()),
159+
)
160+
}
152161
}
153162

154163
impl<M: MutableArray + 'static> MutableArray for MutableFixedSizeListArray<M> {

crates/polars-expr/src/planner.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,5 +634,8 @@ fn create_physical_expr_inner(
634634
false,
635635
)))
636636
},
637+
AnonymousStreamingAgg { .. } => {
638+
polars_bail!(ComputeError: "anonymous agg not supported in in-memory engine")
639+
},
637640
}
638641
}

crates/polars-expr/src/reduce/convert.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,20 @@ pub fn into_reduction(
143143
_ => unreachable!(),
144144
}
145145
},
146+
AExpr::AnonymousStreamingAgg {
147+
input: inner_exprs,
148+
fmt_str: _,
149+
function,
150+
} => {
151+
let ann_agg = function.materialize()?;
152+
assert!(inner_exprs.len() == 1);
153+
let input = inner_exprs[0].node();
154+
let reduction = ann_agg.as_any();
155+
let reduction = reduction
156+
.downcast_ref::<Box<dyn GroupedReduction>>()
157+
.unwrap();
158+
(reduction.new_empty(), input)
159+
},
146160
_ => unreachable!(),
147161
};
148162
Ok(out)

crates/polars-expr/src/reduce/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ pub struct VecGroupedReduction<R: Reducer> {
219219
}
220220

221221
impl<R: Reducer> VecGroupedReduction<R> {
222-
fn new(in_dtype: DataType, reducer: R) -> Self {
222+
pub fn new(in_dtype: DataType, reducer: R) -> Self {
223223
Self {
224224
values: Vec::new(),
225225
evicted_values: Vec::new(),
@@ -486,7 +486,7 @@ where
486486
}
487487

488488
#[derive(Clone)]
489-
struct NullGroupedReduction {
489+
pub struct NullGroupedReduction {
490490
num_groups: IdxSize,
491491
num_evictions: IdxSize,
492492
output: Scalar,
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
use std::any::Any;
2+
use std::sync::Arc;
3+
4+
use polars_core::prelude::Field;
5+
use polars_core::schema::Schema;
6+
use polars_error::{PolarsResult, feature_gated};
7+
#[cfg(feature = "ir_serde")]
8+
use serde::{Deserialize, Deserializer, Serialize, Serializer};
9+
10+
use super::SpecialEq;
11+
use crate::dsl::LazySerde;
12+
13+
/// Interface for an opaque ("anonymous") streaming aggregation.
///
/// Implementors must be `Send + Sync`; the trait is used behind
/// `Arc<dyn AnonymousStreamingAgg>`.
pub trait AnonymousStreamingAgg: Send + Sync {
14+
/// Exposes the implementor as `&dyn Any` so that callers can downcast it
/// to a concrete type.
fn as_any(&self) -> &dyn Any;
15+
16+
/// NOTE(review): presumably resolves the output `Field` of the aggregation
/// from the input schema and the input fields — confirm against implementors.
fn get_field(&self, input_schema: &Schema, fields: &[Field]) -> PolarsResult<Field>;
17+
}
18+
19+
pub type OpaqueStreamingAgg = LazySerde<SpecialEq<Arc<dyn AnonymousStreamingAgg>>>;
20+
21+
impl OpaqueStreamingAgg {
22+
pub fn materialize(&self) -> PolarsResult<SpecialEq<Arc<dyn AnonymousStreamingAgg>>> {
23+
match self {
24+
Self::Deserialized(t) => Ok(t.clone()),
25+
Self::Named {
26+
name,
27+
payload,
28+
value,
29+
} => feature_gated!("serde", {
30+
use super::named_serde::NAMED_SERDE_REGISTRY_EXPR;
31+
match value {
32+
Some(v) => Ok(v.clone()),
33+
None => Ok(SpecialEq::new(
34+
NAMED_SERDE_REGISTRY_EXPR
35+
.read()
36+
.unwrap()
37+
.as_ref()
38+
.expect("NAMED EXPR REGISTRY NOT SET")
39+
.get_agg(name, payload.as_ref().unwrap())?
40+
.expect("NAMED AGG NOT FOUND"),
41+
)),
42+
}
43+
}),
44+
Self::Bytes(_b) => {
45+
feature_gated!("serde", {
46+
use crate::dsl::anonymous::serde_expr;
47+
serde_expr::deserialize_anon_agg(_b.as_ref()).map(SpecialEq::new)
48+
})
49+
},
50+
}
51+
}
52+
}
53+
54+
#[cfg(feature = "ir_serde")]
impl Serialize for SpecialEq<Arc<dyn AnonymousStreamingAgg>> {
    /// Placeholder implementation that only satisfies the `Serialize` bound.
    ///
    /// # Panics
    /// Always panics via `unreachable!` when called.
    fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // The serializer is intentionally never touched; the underscore prefix
        // keeps the `unused_variables` lint quiet.
        unreachable!("should not be hit")
    }
}

crates/polars-plan/src/dsl/expr/anonymous/expr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ where
139139
}
140140

141141
pub type OpaqueColumnUdf = LazySerde<SpecialEq<Arc<dyn AnonymousColumnsUdf>>>;
142-
pub(crate) fn new_column_udf<F: AnonymousColumnsUdf + 'static>(func: F) -> OpaqueColumnUdf {
142+
/// Wraps `func` in an `Arc` and `SpecialEq` and returns it as the
/// `LazySerde::Deserialized` variant of [`OpaqueColumnUdf`].
pub fn new_column_udf<F: AnonymousColumnsUdf + 'static>(func: F) -> OpaqueColumnUdf {
143143
LazySerde::Deserialized(SpecialEq::new(Arc::new(func)))
144144
}
145145

crates/polars-plan/src/dsl/expr/anonymous/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
mod agg;
12
mod expr;
3+
4+
pub use agg::{AnonymousStreamingAgg, OpaqueStreamingAgg};
25
pub use expr::*;
6+
37
#[cfg(feature = "dsl-schema")]
48
mod json_schema;
59
#[cfg(feature = "serde")]

crates/polars-plan/src/dsl/expr/anonymous/named_serde.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,26 @@
11
use std::sync::{Arc, LazyLock, RwLock};
22

3+
use polars_error::PolarsResult;
4+
35
use super::AnonymousColumnsUdf;
6+
use super::agg::AnonymousStreamingAgg;
47

58
// Can be used to have named anonymous functions.
69
// The receiver must have implemented this registry and map the names to the proper UDFs.
710
pub trait ExprRegistry: Sync + Send {
8-
fn get_function(&self, name: &str, payload: &[u8]) -> Option<Arc<dyn AnonymousColumnsUdf>>;
11+
/// Maps `name` (plus its `payload` bytes) to a column UDF.
///
/// The default implementation returns `None`.
// `name` and `payload` are not read by the default body, hence the lint allowance.
#[allow(unused)]
12+
fn get_function(&self, name: &str, payload: &[u8]) -> Option<Arc<dyn AnonymousColumnsUdf>> {
13+
None
14+
}
15+
16+
/// Maps `name` (plus its `payload` bytes) to a streaming aggregation.
///
/// The default implementation returns `Ok(None)`.
// `name` and `payload` are not read by the default body, hence the lint allowance.
#[allow(unused)]
17+
fn get_agg(
18+
&self,
19+
name: &str,
20+
payload: &[u8],
21+
) -> PolarsResult<Option<Arc<dyn AnonymousStreamingAgg>>> {
22+
Ok(None)
23+
}
924
}
1025

1126
pub(super) static NAMED_SERDE_REGISTRY_EXPR: LazyLock<RwLock<Option<Arc<dyn ExprRegistry>>>> =

crates/polars-plan/src/dsl/expr/anonymous/serde_expr.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,35 @@ impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn AnonymousColumnsUdf>> {
119119
}
120120
}
121121

122+
impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn AnonymousStreamingAgg>> {
123+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
124+
where
125+
D: Deserializer<'a>,
126+
{
127+
use serde::de::Error;
128+
deserialize_map_bytes(deserializer, |buf| {
129+
deserialize_anon_agg(&buf)
130+
.map_err(|e| D::Error::custom(format!("{e}")))
131+
.map(SpecialEq::new)
132+
})?
133+
}
134+
}
135+
136+
pub(super) fn deserialize_anon_agg(buf: &[u8]) -> PolarsResult<Arc<dyn AnonymousStreamingAgg>> {
137+
if buf.starts_with(NAMED_SERDE_MAGIC_BYTE_MARK) {
138+
let (reg, name, payload) = deserialize_named_registry(buf)?;
139+
140+
if let Some(func) = reg.get_agg(name, payload)? {
141+
Ok(func)
142+
} else {
143+
let msg = "name not found in named serde registry";
144+
polars_bail!(ComputeError: msg)
145+
}
146+
} else {
147+
polars_bail!(ComputeError: "deserialization not supported for this 'opaque' function")
148+
}
149+
}
150+
122151
// Serialize SpecialEq<T>
123152

124153
impl Serialize for SpecialEq<Series> {

crates/polars-plan/src/plans/aexpr/builder.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,18 @@ impl AExprBuilder {
5555
)
5656
}
5757

58+
pub fn map_as_expr_ir<F: Fn(ExprIR, &mut Arena<AExpr>) -> AExpr>(
59+
self,
60+
mapper: F,
61+
arena: &mut Arena<AExpr>,
62+
) -> Self {
63+
let eir = ExprIR::from_node(self.node, arena);
64+
65+
let ae = mapper(eir, arena);
66+
let node = arena.add(ae);
67+
Self { node }
68+
}
69+
5870
pub fn row_encode_unary(
5971
self,
6072
variant: RowEncodingVariant,

0 commit comments

Comments
 (0)