Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT]: sql cross join #3110

Merged
merged 11 commits into from
Oct 28, 2024
Prev Previous commit
Next Next commit
wip: sql cross join
  • Loading branch information
universalmind303 committed Oct 23, 2024
commit 68bb7ced46d141ba3314d827e4b674d824e8f5ea
2 changes: 2 additions & 0 deletions src/common/error/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ pub enum DaftError {
FmtError(#[from] std::fmt::Error),
#[error("DaftError::RegexError {0}")]
RegexError(#[from] regex::Error),
#[error("PlanningError {0}")]
PlanningError(String),
}

impl From<arrow2::error::Error> for DaftError {
Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/logical_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl Project {
})
}
/// Create a new Projection using the specified output schema
pub fn new_from_schema(input: Arc<LogicalPlan>, schema: SchemaRef) -> Result<Self> {
pub(crate) fn new_from_schema(input: Arc<LogicalPlan>, schema: SchemaRef) -> Result<Self> {
let expr: Vec<ExprRef> = schema
.names()
.into_iter()
Expand Down
13 changes: 6 additions & 7 deletions src/daft-plan/src/logical_optimization/join_key_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl JoinKeySet {
iter: impl IntoIterator<Item = &'a (ExprRef, ExprRef)>,
) -> bool {
let mut inserted = false;
for (left, right) in iter.into_iter() {
for (left, right) in iter {
inserted |= self.insert(left, right);
}
inserted
Expand All @@ -102,19 +102,19 @@ impl JoinKeySet {
/// returns true if any of the pairs were inserted
pub fn insert_all_owned(&mut self, iter: impl IntoIterator<Item = (ExprRef, ExprRef)>) -> bool {
let mut inserted = false;
for (left, right) in iter.into_iter() {
for (left, right) in iter {
inserted |= self.insert_owned(Arc::unwrap_or_clone(left), Arc::unwrap_or_clone(right));
}
inserted
}

/// Inserts any join keys that are common to both `s1` and `s2` into self
pub fn insert_intersection(&mut self, s1: &JoinKeySet, s2: &JoinKeySet) {
pub fn insert_intersection(&mut self, s1: &Self, s2: &Self) {
// note can't use inner.intersection as we need to consider both (l, r)
// and (r, l) in equality
for (left, right) in s1.inner.iter() {
if s2.contains(left, right) {
self.insert(left, right);
for (left, right) in &s1.inner {
if s2.contains(left.as_ref(), right.as_ref()) {
self.insert(left.as_ref(), right.as_ref());
}
}
}
Expand All @@ -140,7 +140,6 @@ impl JoinKeySet {
///
/// This behaves like a `(Expr, Expr)` tuple for hashing and comparison, but
/// avoids copying the values simply to compare them.
///
/// Both fields are shared borrows tied to `'a`, so building an `ExprPair`
/// never clones either expression.

#[derive(Debug, Eq, PartialEq, Hash)]
struct ExprPair<'a>(&'a Expr, &'a Expr);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
/// Heavily inspired by DataFusion's EliminateCrossJoin rule: https://github.com/apache/datafusion/blob/b978cf8236436038a106ed94fb0d7eaa6ba99962/datafusion/optimizer/src/eliminate_cross_join.rs
use std::{collections::HashSet, sync::Arc};

use common_error::DaftResult;
use common_treenode::{DynTreeNode, Transformed, TreeNode, TreeNodeRecursion};
use common_treenode::{Transformed, TreeNode, TreeNodeRecursion};
use daft_core::{
join::{JoinStrategy, JoinType},
join::JoinType,
prelude::{Schema, SchemaRef, TimeUnit},
};
use daft_dsl::{Expr, ExprRef, Operator};
use daft_schema::dtype::DataType;
use indexmap::IndexSet;

use super::OptimizerRule;
use crate::{
logical_ops::{Filter, Join, Project},
logical_optimization::{join_key_set::JoinKeySet, OptimizerConfig},
logical_optimization::join_key_set::JoinKeySet,
LogicalPlan, LogicalPlanRef,
};

Expand All @@ -28,7 +28,6 @@ impl EliminateCrossJoin {

impl OptimizerRule for EliminateCrossJoin {
fn try_optimize(&self, plan: Arc<LogicalPlan>) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
dbg!("EliminateCrossJoin");
let schema = plan.schema();
let mut possible_join_keys = JoinKeySet::new();
let mut all_inputs: Vec<Arc<LogicalPlan>> = vec![];
Expand All @@ -46,8 +45,7 @@ impl OptimizerRule for EliminateCrossJoin {
})
);
if !rewriteable {
todo!()
// return Ok(Transformed::no(Arc::new(filter)));
return rewrite_children(self, Arc::new(LogicalPlan::Filter(filter)));
}
if !can_flatten_join_inputs(filter.input.as_ref()) {
return Ok(Transformed::no(Arc::new(LogicalPlan::Filter(filter))));
Expand Down Expand Up @@ -89,7 +87,7 @@ impl OptimizerRule for EliminateCrossJoin {
}
left = rewrite_children(self, left)?.data;
if schema != left.schema() {
let project = Project::new_from_schema(left, schema.clone())?;
let project = Project::new_from_schema(left, schema)?;

left = Arc::new(LogicalPlan::Project(project));
}
Expand Down Expand Up @@ -130,7 +128,7 @@ fn flatten_join_inputs(
) -> DaftResult<()> {
match plan {
LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
let keys = join.left_on.into_iter().zip(join.right_on.into_iter());
let keys = join.left_on.into_iter().zip(join.right_on);
possible_join_keys.insert_all_owned(keys);
flatten_join_inputs(
Arc::unwrap_or_clone(join.left),
Expand Down Expand Up @@ -162,16 +160,15 @@ fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool {
};

for child in plan.children() {
match child {
if matches!(
child,
LogicalPlan::Join(Join {
join_type: JoinType::Inner,
..
}) => {
if !can_flatten_join_inputs(child) {
return false;
}
}
_ => (),
})
) && !can_flatten_join_inputs(child)
{
return false;
}
}
true
Expand All @@ -187,17 +184,17 @@ fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) {
}
Operator::And => {
extract_possible_join_keys(left, join_keys);
extract_possible_join_keys(right, join_keys)
extract_possible_join_keys(right, join_keys);
}
// Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
// Join predicates nested inside an OR expression are also pulled up properly.
Operator::Or => {
let mut left_join_keys = JoinKeySet::new();
let mut right_join_keys = JoinKeySet::new();

extract_possible_join_keys(left, &mut left_join_keys);
extract_possible_join_keys(right, &mut right_join_keys);

join_keys.insert_intersection(&left_join_keys, &right_join_keys)
join_keys.insert_intersection(&left_join_keys, &right_join_keys);
}
_ => (),
};
Expand All @@ -219,7 +216,7 @@ fn remove_join_expressions(expr: ExprRef, join_keys: &JoinKeySet) -> Option<Expr
// was a join key, so remove it
None
}
// Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
// Join predicates nested inside an OR expression are also pulled up properly.
Expr::BinaryOp { left, op, right } if op == Operator::And => {
let l = remove_join_expressions(left, join_keys);
let r = remove_join_expressions(right, join_keys);
Expand Down Expand Up @@ -318,7 +315,7 @@ fn find_inner_join(
.schema()
.non_distinct_union(right.schema().as_ref());

return Ok(LogicalPlan::Join(Join {
Ok(LogicalPlan::Join(Join {
left: left_input,
right,
left_on: vec![],
Expand All @@ -327,16 +324,16 @@ fn find_inner_join(
join_strategy: None,
output_schema: Arc::new(join_schema),
})
.arced());
.arced())
}

/// Check whether all columns are from the schema.
pub fn check_all_columns_from_schema(
columns: &HashSet<Arc<str>>,
schema: &Schema,
) -> DaftResult<bool> {
for col in columns.iter() {
let exist = schema.get_index(&col).is_ok();
for col in columns {
let exist = schema.get_index(col).is_ok();

if !exist {
return Ok(false);
Expand Down Expand Up @@ -372,11 +369,11 @@ pub fn find_valid_equijoin_key_pair(
if check_all_columns_from_schema(&left_using_columns, &left_schema)?
&& check_all_columns_from_schema(&right_using_columns, &right_schema)?
{
return Ok(Some((left_key.clone(), right_key.clone())));
return Ok(Some((left_key, right_key)));
} else if check_all_columns_from_schema(&right_using_columns, &left_schema)?
&& check_all_columns_from_schema(&left_using_columns, &right_schema)?
{
return Ok(Some((right_key.clone(), left_key.clone())));
return Ok(Some((right_key, left_key)));
}

Ok(None)
Expand Down
4 changes: 2 additions & 2 deletions src/daft-plan/src/logical_optimization/rules/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
mod drop_repartition;
mod eliminate_cross_join;
mod push_down_filter;
mod push_down_limit;
mod push_down_projection;
mod rule;
mod split_actor_pool_projects;
mod eliminate_cross_join;

pub use drop_repartition::DropRepartition;
pub use eliminate_cross_join::EliminateCrossJoin;
pub use push_down_filter::PushDownFilter;
pub use push_down_limit::PushDownLimit;
pub use push_down_projection::PushDownProjection;
pub use rule::OptimizerRule;
pub use split_actor_pool_projects::SplitActorPoolProjects;
pub use eliminate_cross_join::EliminateCrossJoin;
6 changes: 4 additions & 2 deletions src/daft-plan/src/physical_planner/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
};

use common_daft_config::DaftExecutionConfig;
use common_error::DaftResult;
use common_error::{DaftError, DaftResult};
use common_file_formats::FileFormat;
use daft_core::prelude::*;
use daft_dsl::{
Expand Down Expand Up @@ -429,7 +429,9 @@ pub(super) fn translate_single_logical_node(
..
}) => {
if left_on.is_empty() && right_on.is_empty() && join_type == &JoinType::Inner {
todo!("Cross join not yet implemented")
return Err(DaftError::PlanningError(
"Cross join is not supported".to_string(),
));
}
let mut right_physical = physical_children.pop().expect("requires 1 inputs");
let mut left_physical = physical_children.pop().expect("requires 2 inputs");
Expand Down
10 changes: 8 additions & 2 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::{
error::{PlannerError, SQLPlannerResult},
invalid_operation_err, table_not_found_err, unsupported_sql_err,
};

/// A named logical plan
/// This is used to keep track of the table name associated with a logical plan while planning a SQL query
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -297,8 +298,11 @@ impl SQLPlanner {

let first = from_iter.next().unwrap();
let mut rel = self.plan_relation(&first.relation)?;
self.table_map.insert(rel.get_name(), rel.clone());
for tbl in from_iter {
let right = self.plan_relation(&tbl.relation)?;
self.table_map.insert(right.get_name(), right.clone());
let right_join_prefix = Some(format!("{}.", right.get_name()));

rel.inner = rel.inner.join(
right.inner,
Expand All @@ -307,13 +311,13 @@ impl SQLPlanner {
JoinType::Inner,
None,
None,
None,
right_join_prefix.as_deref(),
)?;
}
return Ok(rel);
}

let from = from.into_iter().next().unwrap();
let from = from.iter().next().unwrap();

fn collect_compound_identifiers(
left: &[Ident],
Expand Down Expand Up @@ -515,6 +519,7 @@ impl SQLPlanner {

let root = idents.next().unwrap();
let root = ident_to_str(root);

let current_relation = match self.table_map.get(&root) {
Some(rel) => rel,
None => {
Expand All @@ -539,6 +544,7 @@ impl SQLPlanner {
// If duplicate columns are present in the schema, it adds the table name as a prefix. (df.column_name)
// So we first check if the prefixed column name is present in the schema.
let current_schema = self.relation_opt().unwrap().inner.schema();

let f = current_schema.get_field(&ident_str).ok();
if let Some(field) = f {
Ok(vec![col(field.name.clone())])
Expand Down
Loading