Skip to content

Commit

Permalink
feat: support normalize_eq in cse optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuliquan committed Nov 12, 2024
1 parent 3f7bdb5 commit dda0848
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 64 deletions.
76 changes: 45 additions & 31 deletions datafusion/common/src/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,52 +50,62 @@ impl<T: HashNode + ?Sized> HashNode for Arc<T> {
}
}

/// A trait that defines how to normalize a node.
/// The `Normalizeable` trait defines a method to determine whether a node can be normalized.
///
/// This trait is used to normalize nodes before comparing them for CSE. Normalization
/// can be used to ensure that two nodes that are semantically equivalent are considered
/// equal for CSE.
/// For example:`a + b` and `b + a` are semantically equivalent.
pub trait NormalizeNode: Eq {
fn normalize(&self) -> Self;
fn enable_normalized(&self) -> bool;
/// Normalization is the process of converting a node into a canonical form that can be used
/// to compare nodes for equality. This is useful in optimizations like Common Subexpression Elimination (CSE),
/// where semantically equivalent nodes (e.g., `a + b` and `b + a`) should be treated as equal.
pub trait Normalizeable {
fn can_normalize(&self) -> bool;
}

/// The `NormalizeEq` trait extends `Eq` and `Normalizeable` to provide a method for comparing
/// normlized nodes in optimizations like Common Subexpression Elimination (CSE).
///
/// The `normalize_eq` method ensures that two nodes that are semantically equivalent (after normalization)
/// are considered equal in CSE optimization, even if their original forms differ.
///
/// This trait allows for equality comparisons between nodes with equivalent semantics, regardless of their
/// internal representations.
pub trait NormalizeEq: Eq + Normalizeable {
fn normalize_eq(&self, other: &Self) -> bool;
}

/// Identifier that represents a [`TreeNode`] tree.
///
/// This identifier is designed to be efficient and "hash", "accumulate", "equal" and
/// "have no collision (as low as possible)"
#[derive(Debug, Eq)]
struct Identifier<'n, N: NormalizeNode> {
struct Identifier<'n, N: NormalizeEq> {
// Hash of `node` built up incrementally during the first, visiting traversal.
// Its value is not necessarily equal to default hash of the node. E.g. it is not
// equal to `expr.hash()` if the node is `Expr`.
hash: u64,
node: &'n N,
}

impl<N: NormalizeNode> Clone for Identifier<'_, N> {
impl<N: NormalizeEq> Clone for Identifier<'_, N> {
fn clone(&self) -> Self {
*self
}
}
impl<N: NormalizeNode> Copy for Identifier<'_, N> {}
impl<N: NormalizeEq> Copy for Identifier<'_, N> {}

impl<N: NormalizeNode> Hash for Identifier<'_, N> {
impl<N: NormalizeEq> Hash for Identifier<'_, N> {
fn hash<H: Hasher>(&self, state: &mut H) {
state.write_u64(self.hash);
}
}

impl<N: NormalizeNode> PartialEq for Identifier<'_, N> {
impl<N: NormalizeEq> PartialEq for Identifier<'_, N> {
fn eq(&self, other: &Self) -> bool {
self.node.normalize() == other.node.normalize()
self.node.normalize_eq(other.node)
}
}

impl<'n, N> Identifier<'n, N>
where
N: HashNode + NormalizeNode,
N: HashNode + NormalizeEq,
{
fn new(node: &'n N, random_state: &RandomState) -> Self {
let mut hasher = random_state.build_hasher();
Expand Down Expand Up @@ -235,7 +245,7 @@ pub enum FoundCommonNodes<N> {
/// because they should not be recognized as common subtree.
struct CSEVisitor<'a, 'n, N, C>
where
N: NormalizeNode,
N: NormalizeEq,
C: CSEController<Node = N>,
{
/// statistics of [`TreeNode`]s
Expand Down Expand Up @@ -270,7 +280,7 @@ where
/// Record item that used when traversing a [`TreeNode`] tree.
enum VisitRecord<'n, N>
where
N: NormalizeNode,
N: NormalizeEq,
{
/// Marks the beginning of [`TreeNode`]. It contains:
/// - The post-order index assigned during the first, visiting traversal.
Expand All @@ -287,7 +297,7 @@ where

impl<'n, N, C> CSEVisitor<'_, 'n, N, C>
where
N: TreeNode + HashNode + NormalizeNode,
N: TreeNode + HashNode + NormalizeEq,
C: CSEController<Node = N>,
{
/// Find the first `EnterMark` in the stack, and accumulates every `NodeItem` before
Expand All @@ -304,15 +314,15 @@ where
/// an extra traversal).
fn pop_enter_mark(
&mut self,
enable_normalize: bool,
can_normalize: bool,
) -> (usize, Option<Identifier<'n, N>>, bool) {
let mut node_ids: Vec<Identifier<'n, N>> = vec![];
let mut is_valid = true;

while let Some(item) = self.visit_stack.pop() {
match item {
VisitRecord::EnterMark(down_index) => {
if enable_normalize {
if can_normalize {
node_ids.sort_by_key(|i| i.hash);
}
let node_id = node_ids
Expand All @@ -332,7 +342,7 @@ where

impl<'n, N, C> TreeNodeVisitor<'n> for CSEVisitor<'_, 'n, N, C>
where
N: TreeNode + HashNode + NormalizeNode,
N: TreeNode + HashNode + NormalizeEq,
C: CSEController<Node = N>,
{
type Node = N;
Expand Down Expand Up @@ -374,7 +384,7 @@ where

fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
let (down_index, sub_node_id, sub_node_is_valid) =
self.pop_enter_mark(node.enable_normalized());
self.pop_enter_mark(node.can_normalize());

let node_id = Identifier::new(node, self.random_state).combine(sub_node_id);
let is_valid = C::is_valid(node) && sub_node_is_valid;
Expand Down Expand Up @@ -414,7 +424,7 @@ where
/// replaced [`TreeNode`] tree.
struct CSERewriter<'a, 'n, N, C>
where
N: NormalizeNode,
N: NormalizeEq,
C: CSEController<Node = N>,
{
/// statistics of [`TreeNode`]s
Expand All @@ -435,7 +445,7 @@ where

impl<N, C> TreeNodeRewriter for CSERewriter<'_, '_, N, C>
where
N: TreeNode + NormalizeNode,
N: TreeNode + NormalizeEq,
C: CSEController<Node = N>,
{
type Node = N;
Expand Down Expand Up @@ -509,7 +519,7 @@ pub struct CSE<N, C: CSEController<Node = N>> {

impl<N, C> CSE<N, C>
where
N: TreeNode + HashNode + Clone + NormalizeNode,
N: TreeNode + HashNode + Clone + NormalizeEq,
C: CSEController<Node = N>,
{
pub fn new(controller: C) -> Self {
Expand Down Expand Up @@ -668,7 +678,8 @@ where
mod test {
use crate::alias::AliasGenerator;
use crate::cse::{
CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeNode, CSE,
CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeEq,
Normalizeable, CSE,
};
use crate::tree_node::tests::TestTreeNode;
use crate::Result;
Expand Down Expand Up @@ -735,15 +746,18 @@ mod test {
}
}

impl NormalizeNode for TestTreeNode<String> {
fn normalize(&self) -> Self {
self.clone()
}
fn enable_normalized(&self) -> bool {
impl Normalizeable for TestTreeNode<String> {
fn can_normalize(&self) -> bool {
false
}
}

impl NormalizeEq for TestTreeNode<String> {
fn normalize_eq(&self, other: &Self) -> bool {
self == other
}
}

#[test]
fn id_array_visitor() -> Result<()> {
let alias_generator = AliasGenerator::new();
Expand Down
66 changes: 33 additions & 33 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use crate::{
};

use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::cse::{HashNode, NormalizeNode};
use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable};
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
Expand Down Expand Up @@ -1674,8 +1674,8 @@ impl Expr {
}
}

impl NormalizeNode for Expr {
fn enable_normalized(&self) -> bool {
impl Normalizeable for Expr {
fn can_normalize(&self) -> bool {
#[allow(clippy::match_like_matches_macro)]
match self {
Expr::BinaryExpr(BinaryExpr {
Expand All @@ -1692,18 +1692,29 @@ impl NormalizeNode for Expr {
_ => false,
}
}
}

fn normalize(&self) -> Expr {
match self {
Expr::BinaryExpr(BinaryExpr {
ref left,
ref op,
ref right,
}) => {
let normalized_left = left.normalize();
let normalized_right = right.normalize();
let new_binary = if matches!(
op,
impl NormalizeEq for Expr {
fn normalize_eq(&self, other: &Self) -> bool {
match (self, other) {
(
Expr::BinaryExpr(BinaryExpr {
left: self_left,
op: self_op,
right: self_right,
}),
Expr::BinaryExpr(BinaryExpr {
left: other_left,
op: other_op,
right: other_right,
}),
) => {
if self_op != other_op {
return false;
}

if matches!(
self_op,
Operator::Plus
| Operator::Multiply
| Operator::BitwiseAnd
Expand All @@ -1712,27 +1723,16 @@ impl NormalizeNode for Expr {
| Operator::Eq
| Operator::NotEq
) {
let (l_expr, r_expr) =
if format!("{normalized_left}") < format!("{normalized_right}") {
(normalized_left, normalized_right)
} else {
(normalized_right, normalized_left)
};
BinaryExpr {
left: Box::new(l_expr),
op: *op,
right: Box::new(r_expr),
}
(self_left.normalize_eq(other_left)
&& self_right.normalize_eq(other_right))
|| (self_left.normalize_eq(other_right)
&& self_right.normalize_eq(other_left))
} else {
BinaryExpr {
left: Box::new(normalized_left),
op: *op,
right: Box::new(normalized_right),
}
};
Expr::BinaryExpr(new_binary)
self_left.normalize_eq(other_left)
&& self_right.normalize_eq(other_right)
}
}
other => other.clone(),
(_, _) => self == other,
}
}
}
Expand Down
42 changes: 42 additions & 0 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,48 @@ mod test {
Ok(())
}

#[test]
fn test_normalize_complex_expression() -> Result<()> {
// case1: a + b * c <=> b * c + a
let table_scan = test_table_scan()?;
let expr = ((col("a") + col("b") * col("c")) - (col("b") * col("c") + col("a")))
.eq(lit(30));
let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;

let expected = "Projection: test.a, test.b, test.c\
\n Filter: __common_expr_1 - __common_expr_1 = Int32(30)\
\n Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c\
\n TableScan: test";
assert_optimized_plan_eq(expected, plan, None);

// ((c1 + c2 / c3) * c3 <=> c3 * (c2 / c3 + c1))
let table_scan = test_table_scan()?;
let expr = (((col("a") + col("b") / col("c")) * col("c"))
/ (col("c") * (col("b") / col("c") + col("a")))
+ col("a"))
.eq(lit(30));
let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
let expected = "Projection: test.a, test.b, test.c\
\n Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)\
\n Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c\
\n TableScan: test";
assert_optimized_plan_eq(expected, plan, None);

// c2 / (c1 + c3) <=> c2 / (c3 + c1)
let table_scan = test_table_scan()?;
let expr = ((col("b") / (col("a") + col("c")))
* (col("b") / (col("c") + col("a"))))
.eq(lit(30));
let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
let expected = "Projection: test.a, test.b, test.c\
\n Filter: __common_expr_1 * __common_expr_1 = Int32(30)\
\n Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c\
\n TableScan: test";
assert_optimized_plan_eq(expected, plan, None);

Ok(())
}

/// returns a "random" function that is marked volatile (aka each invocation
/// returns a different value)
///
Expand Down

0 comments on commit dda0848

Please sign in to comment.