Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: simplify code of eliminate_cross_join.rs #7561

Merged
merged 1 commit into from
Sep 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 10 additions & 24 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -959,31 +959,17 @@ pub fn find_valid_equijoin_key_pair(
return Ok(None);
}

let l_is_left =
check_all_columns_from_schema(&left_using_columns, left_schema.clone())?;
let r_is_right =
check_all_columns_from_schema(&right_using_columns, right_schema.clone())?;

let r_is_left_and_l_is_right = || {
let result =
check_all_columns_from_schema(&right_using_columns, left_schema.clone())?
&& check_all_columns_from_schema(
&left_using_columns,
right_schema.clone(),
)?;

Result::<_>::Ok(result)
};

let join_key_pair = match (l_is_left, r_is_right) {
(true, true) => Some((left_key.clone(), right_key.clone())),
(_, _) if r_is_left_and_l_is_right()? => {
Some((right_key.clone(), left_key.clone()))
}
_ => None,
};
if check_all_columns_from_schema(&left_using_columns, left_schema.clone())?
&& check_all_columns_from_schema(&right_using_columns, right_schema.clone())?
{
return Ok(Some((left_key.clone(), right_key.clone())));
} else if check_all_columns_from_schema(&right_using_columns, left_schema)?
&& check_all_columns_from_schema(&left_using_columns, right_schema)?
{
return Ok(Some((right_key.clone(), left_key.clone())));
}

Ok(join_key_pair)
Ok(None)
}

/// Creates a detailed error message for a function with wrong signature.
Expand Down
93 changes: 38 additions & 55 deletions datafusion/optimizer/src/eliminate_cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use datafusion_expr::logical_plan::{
CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
};
use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
use datafusion_expr::{and, build_join_schema, or, ExprSchemable, Operator};
use datafusion_expr::{build_join_schema, ExprSchemable, Operator};

/// Stateless optimizer rule type: a unit struct for which `OptimizerRule` is
/// implemented in this file.
///
/// The visible part of that implementation flattens the inputs of inner
/// `Join` and `CrossJoin` plans and collects candidate join-key pairs.
/// NOTE(review): presumably the rule then rewrites cross joins into inner
/// joins when usable equijoin keys are found — confirm against the full rule.
#[derive(Default)]
pub struct EliminateCrossJoin;
Expand Down Expand Up @@ -61,14 +61,11 @@ impl OptimizerRule for EliminateCrossJoin {
let mut possible_join_keys: Vec<(Expr, Expr)> = vec![];
let mut all_inputs: Vec<LogicalPlan> = vec![];
let did_flat_successfully = match &input {
LogicalPlan::Join(join) if (join.join_type == JoinType::Inner) => {
try_flatten_join_inputs(
&input,
&mut possible_join_keys,
&mut all_inputs,
)?
}
LogicalPlan::CrossJoin(_) => try_flatten_join_inputs(
LogicalPlan::Join(Join {
join_type: JoinType::Inner,
..
})
| LogicalPlan::CrossJoin(_) => try_flatten_join_inputs(
&input,
&mut possible_join_keys,
&mut all_inputs,
Expand Down Expand Up @@ -164,16 +161,11 @@ fn try_flatten_join_inputs(

for child in children.iter() {
match *child {
LogicalPlan::Join(left_join) => {
if left_join.join_type == JoinType::Inner {
if !try_flatten_join_inputs(child, possible_join_keys, all_inputs)? {
return Ok(false);
}
} else {
all_inputs.push((*child).clone());
}
}
LogicalPlan::CrossJoin(_) => {
LogicalPlan::Join(Join {
join_type: JoinType::Inner,
..
})
| LogicalPlan::CrossJoin(_) => {
if !try_flatten_join_inputs(child, possible_join_keys, all_inputs)? {
return Ok(false);
}
Expand Down Expand Up @@ -202,13 +194,10 @@ fn find_inner_join(
)?;

// Save join keys
match key_pair {
Some((valid_l, valid_r)) => {
if can_hash(&valid_l.get_type(left_input.schema())?) {
join_keys.push((valid_l, valid_r));
}
if let Some((valid_l, valid_r)) = key_pair {
if can_hash(&valid_l.get_type(left_input.schema())?) {
join_keys.push((valid_l, valid_r));
}
_ => continue,
}
}

Expand Down Expand Up @@ -303,39 +292,33 @@ fn remove_join_expressions(
join_keys: &HashSet<(Expr, Expr)>,
) -> Result<Option<Expr>> {
match expr {
Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
Operator::Eq => {
if join_keys.contains(&(*left.clone(), *right.clone()))
|| join_keys.contains(&(*right.clone(), *left.clone()))
{
Ok(None)
} else {
Ok(Some(expr.clone()))
}
}
Operator::And => {
let l = remove_join_expressions(left, join_keys)?;
let r = remove_join_expressions(right, join_keys)?;
match (l, r) {
(Some(ll), Some(rr)) => Ok(Some(and(ll, rr))),
(Some(ll), _) => Ok(Some(ll)),
(_, Some(rr)) => Ok(Some(rr)),
_ => Ok(None),
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
match op {
Operator::Eq => {
if join_keys.contains(&(*left.clone(), *right.clone()))
|| join_keys.contains(&(*right.clone(), *left.clone()))
{
Ok(None)
} else {
Ok(Some(expr.clone()))
}
}
}
// Fix for issue #78: join predicates inside an OR expression are also pulled up properly.
Operator::Or => {
let l = remove_join_expressions(left, join_keys)?;
let r = remove_join_expressions(right, join_keys)?;
match (l, r) {
(Some(ll), Some(rr)) => Ok(Some(or(ll, rr))),
(Some(ll), _) => Ok(Some(ll)),
(_, Some(rr)) => Ok(Some(rr)),
_ => Ok(None),
// Fix for issue #78: join predicates inside an OR expression are also pulled up properly.
Operator::And | Operator::Or => {
let l = remove_join_expressions(left, join_keys)?;
let r = remove_join_expressions(right, join_keys)?;
match (l, r) {
(Some(ll), Some(rr)) => Ok(Some(Expr::BinaryExpr(
BinaryExpr::new(Box::new(ll), *op, Box::new(rr)),
))),
(Some(ll), _) => Ok(Some(ll)),
(_, Some(rr)) => Ok(Some(rr)),
_ => Ok(None),
}
}
_ => Ok(Some(expr.clone())),
}
_ => Ok(Some(expr.clone())),
},
}
_ => Ok(Some(expr.clone())),
}
}
Expand Down