Skip to content

Commit f57a7a2

Browse files
coastalwhitegfvioli
authored andcommitted
chore: Add AExpr::is_expr_equal_to (pola-rs#23740)
1 parent f1242d4 commit f57a7a2

File tree

3 files changed

+193
-0
lines changed

3 files changed

+193
-0
lines changed
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
use polars_core::prelude::SortOptions;
2+
use polars_utils::arena::{Arena, Node};
3+
4+
use super::{AExpr, IRAggExpr};
5+
6+
impl AExpr {
7+
pub fn is_expr_equal_to(&self, other: &Self, arena: &Arena<AExpr>) -> bool {
8+
let mut l_stack = Vec::new();
9+
let mut r_stack = Vec::new();
10+
self.is_expr_equal_to_amortized(other, arena, &mut l_stack, &mut r_stack)
11+
}
12+
13+
pub fn is_expr_equal_to_amortized(
14+
&self,
15+
other: &Self,
16+
arena: &Arena<AExpr>,
17+
l_stack: &mut Vec<Node>,
18+
r_stack: &mut Vec<Node>,
19+
) -> bool {
20+
l_stack.clear();
21+
r_stack.clear();
22+
23+
// Top-Level node.
24+
if !self.is_expr_equal_top_level(other) {
25+
return false;
26+
}
27+
self.children_rev(l_stack);
28+
other.children_rev(r_stack);
29+
30+
// Traverse node in N R L order
31+
loop {
32+
assert_eq!(l_stack.len(), r_stack.len());
33+
34+
let (Some(l_node), Some(r_node)) = (l_stack.pop(), r_stack.pop()) else {
35+
break;
36+
};
37+
38+
let l_expr = arena.get(l_node);
39+
let r_expr = arena.get(r_node);
40+
41+
if !l_expr.is_expr_equal_top_level(r_expr) {
42+
return false;
43+
}
44+
l_expr.children_rev(l_stack);
45+
r_expr.children_rev(r_stack);
46+
}
47+
48+
true
49+
}
50+
51+
pub fn is_expr_equal_top_level(&self, other: &Self) -> bool {
52+
if std::mem::discriminant(self) != std::mem::discriminant(other) {
53+
// Fast path: different kind of expression.
54+
return false;
55+
}
56+
57+
use AExpr as E;
58+
59+
// @NOTE: Intentionally written as a match statement over only `self` as it forces the
60+
// match to be exhaustive.
61+
#[rustfmt::skip]
62+
let is_equal = match self {
63+
E::Explode { expr: _, skip_empty: l_skip_empty } => matches!(other, E::Explode { expr: _, skip_empty: r_skip_empty } if l_skip_empty == r_skip_empty),
64+
E::Column(l_name) => matches!(other, E::Column(r_name) if l_name == r_name),
65+
E::Literal(l_lit) => matches!(other, E::Literal(r_lit) if l_lit == r_lit),
66+
E::BinaryExpr { left: _, op: l_op, right: _ } => matches!(other, E::BinaryExpr { left: _, op: r_op, right: _ } if l_op == r_op),
67+
E::Cast { expr: _, dtype: l_dtype, options: l_options } => matches!(other, E::Cast { expr: _, dtype: r_dtype, options: r_options } if l_dtype == r_dtype && l_options == r_options),
68+
E::Sort { expr: _, options: l_options } => matches!(other, E::Sort { expr: _, options: r_options } if l_options == r_options),
69+
E::Gather { expr: _, idx: l_idx, returns_scalar: l_returns_scalar } => matches!(other, E::Gather { expr: _, idx: r_idx, returns_scalar: r_returns_scalar } if l_idx == r_idx && l_returns_scalar == r_returns_scalar),
70+
E::SortBy { expr: _, by: l_by, sort_options: l_sort_options } => matches!(other, E::SortBy { expr: _, by: r_by, sort_options: r_sort_options } if l_by.len() == r_by.len() && l_sort_options == r_sort_options),
71+
E::Agg(l_agg) => matches!(other, E::Agg(r_agg) if l_agg.is_agg_equal_top_level(r_agg)),
72+
E::AnonymousFunction { input: l_input, function: l_function, output_type: l_output_type, options: l_options, fmt_str: l_fmt_str } => matches!(other, E::AnonymousFunction { input: r_input, function: r_function, output_type: r_output_type, options: r_options, fmt_str: r_fmt_str } if l_input.len() == r_input.len() && l_function == r_function && l_output_type == r_output_type && l_options == r_options && l_fmt_str == r_fmt_str),
73+
E::Eval { expr: _, evaluation: _, variant: l_variant } => matches!(other, E::Eval { expr: _, evaluation: _, variant: r_variant } if l_variant == r_variant),
74+
E::Function { input: l_input, function: l_function, options: l_options } => matches!(other, E::Function { input: r_input, function: r_function, options: r_options } if l_input.len() == r_input.len() && l_function == r_function && l_options == r_options),
75+
E::Window { function: _, partition_by: l_partition_by, order_by: l_order_by, options: l_options } => matches!(other, E::Window { function: _, partition_by: r_partition_by, order_by: r_order_by, options: r_options } if l_partition_by.len() == r_partition_by.len() && l_order_by.as_ref().map(|(_, v): &(Node, SortOptions)| v) == r_order_by.as_ref().map(|(_, v): &(Node, SortOptions)| v) && l_options == r_options),
76+
77+
// Discriminant check done above.
78+
E::Filter { input: _, by: _ } |
79+
E::Ternary { predicate: _, truthy: _, falsy: _ } |
80+
E::Slice { input: _, offset: _, length: _ } |
81+
E::Len => true,
82+
};
83+
84+
is_equal
85+
}
86+
}
87+
88+
impl IRAggExpr {
89+
pub fn is_agg_equal_top_level(&self, other: &Self) -> bool {
90+
if std::mem::discriminant(self) != std::mem::discriminant(other) {
91+
// Fast path: different kind of expression.
92+
return false;
93+
}
94+
95+
use IRAggExpr as A;
96+
97+
// @NOTE: Intentionally written as a match statement over only `self` as it forces the
98+
// match to be exhaustive.
99+
#[rustfmt::skip]
100+
let is_equal = match self {
101+
A::Min { input: _, propagate_nans: l_propagate_nans } => matches!(other, A::Min { input: _, propagate_nans: r_propagate_nans } if l_propagate_nans == r_propagate_nans),
102+
A::Max { input: _, propagate_nans: l_propagate_nans } => matches!(other, A::Max { input: _, propagate_nans: r_propagate_nans } if l_propagate_nans == r_propagate_nans),
103+
A::Quantile { expr: _, quantile: _, method: l_method } => matches!(other, A::Quantile { expr: _, quantile: _, method: r_method } if l_method == r_method),
104+
A::Count(_, l_include_nulls) => matches!(other, A::Count(_, r_include_nulls) if l_include_nulls == r_include_nulls),
105+
A::Std(_, l_ddof) => matches!(other, A::Std(_, r_ddof) if l_ddof == r_ddof),
106+
A::Var(_, l_ddof) => matches!(other, A::Var(_, r_ddof) if l_ddof == r_ddof),
107+
108+
// Discriminant check done above.
109+
A::Median(_) |
110+
A::NUnique(_) |
111+
A::First(_) |
112+
A::Last(_) |
113+
A::Mean(_) |
114+
A::Implode(_) |
115+
A::Sum(_) |
116+
A::AggGroups(_) => true,
117+
};
118+
119+
is_equal
120+
}
121+
}

crates/polars-plan/src/plans/aexpr/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
mod builder;
2+
mod equality;
23
mod evaluate;
34
mod function_expr;
45
#[cfg(feature = "cse")]

crates/polars-plan/src/plans/aexpr/traverse.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ use super::*;
33
impl AExpr {
44
/// Push the inputs of this node to the given container, in reverse order.
55
/// This ensures the primary node responsible for the name is pushed last.
6+
///
7+
/// This is subtlely different from `children_rev` as this only includes the input expressions,
8+
/// not expressions used during evaluation.
69
pub fn inputs_rev<E>(&self, container: &mut E)
710
where
811
E: Extend<Node>,
@@ -73,6 +76,74 @@ impl AExpr {
7376
}
7477
}
7578

79+
/// Push the children of this node to the given container, in reverse order.
80+
/// This ensures the primary node responsible for the name is pushed last.
81+
///
82+
/// This is subtlely different from `input_rev` as this only all expressions included in the
83+
/// expression not only the input expressions,
84+
pub fn children_rev<E: Extend<Node>>(&self, container: &mut E) {
85+
use AExpr::*;
86+
87+
match self {
88+
Column(_) | Literal(_) | Len => {},
89+
BinaryExpr { left, op: _, right } => {
90+
container.extend([*right, *left]);
91+
},
92+
Cast { expr, .. } => container.extend([*expr]),
93+
Sort { expr, .. } => container.extend([*expr]),
94+
Gather { expr, idx, .. } => {
95+
container.extend([*idx, *expr]);
96+
},
97+
SortBy { expr, by, .. } => {
98+
container.extend(by.iter().cloned().rev());
99+
container.extend([*expr]);
100+
},
101+
Filter { input, by } => {
102+
container.extend([*by, *input]);
103+
},
104+
Agg(agg_e) => match agg_e.get_input() {
105+
NodeInputs::Single(node) => container.extend([node]),
106+
NodeInputs::Many(nodes) => container.extend(nodes.into_iter().rev()),
107+
NodeInputs::Leaf => {},
108+
},
109+
Ternary {
110+
truthy,
111+
falsy,
112+
predicate,
113+
} => {
114+
container.extend([*predicate, *falsy, *truthy]);
115+
},
116+
AnonymousFunction { input, .. } | Function { input, .. } => {
117+
container.extend(input.iter().rev().map(|e| e.node()))
118+
},
119+
Explode { expr: e, .. } => container.extend([*e]),
120+
Window {
121+
function,
122+
partition_by,
123+
order_by,
124+
options: _,
125+
} => {
126+
if let Some((n, _)) = order_by {
127+
container.extend([*n]);
128+
}
129+
container.extend(partition_by.iter().rev().cloned());
130+
container.extend([*function]);
131+
},
132+
Eval {
133+
expr,
134+
evaluation,
135+
variant: _,
136+
} => container.extend([*evaluation, *expr]),
137+
Slice {
138+
input,
139+
offset,
140+
length,
141+
} => {
142+
container.extend([*length, *offset, *input]);
143+
},
144+
}
145+
}
146+
76147
pub fn replace_inputs(mut self, inputs: &[Node]) -> Self {
77148
use AExpr::*;
78149
let input = match &mut self {

0 commit comments

Comments
 (0)