Skip to content

Commit

Permalink
feat: Support SQL filter clause for aggregate expressions, add SQL di…
Browse files Browse the repository at this point in the history
…alect support (#5868)
  • Loading branch information
yjshen authored Apr 11, 2023
1 parent aed319c commit dafe997
Show file tree
Hide file tree
Showing 22 changed files with 548 additions and 102 deletions.
4 changes: 4 additions & 0 deletions datafusion/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ config_namespace! {
/// When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted)
pub enable_ident_normalization: bool, default = true

/// Configure the SQL dialect used by DataFusion's parser; supported values include: Generic,
/// MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi.
pub dialect: String, default = "generic".to_string()

}
}

Expand Down
35 changes: 33 additions & 2 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ use datafusion_sql::{
planner::{ContextProvider, SqlToRel},
};
use parquet::file::properties::WriterProperties;
use sqlparser::dialect::{
AnsiDialect, BigQueryDialect, ClickHouseDialect, Dialect, GenericDialect,
HiveDialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect, RedshiftSqlDialect,
SQLiteDialect, SnowflakeDialect,
};
use url::Url;

use crate::catalog::information_schema::{InformationSchemaProvider, INFORMATION_SCHEMA};
Expand Down Expand Up @@ -1500,8 +1505,10 @@ impl SessionState {
pub fn sql_to_statement(
&self,
sql: &str,
dialect: &str,
) -> Result<datafusion_sql::parser::Statement> {
let mut statements = DFParser::parse_sql(sql)?;
let dialect = create_dialect_from_str(dialect)?;
let mut statements = DFParser::parse_sql_with_dialect(sql, dialect.as_ref())?;
if statements.len() > 1 {
return Err(DataFusionError::NotImplemented(
"The context currently only supports a single SQL statement".to_string(),
Expand Down Expand Up @@ -1629,7 +1636,8 @@ impl SessionState {
///
/// See [`SessionContext::sql`] for a higher-level interface that also handles DDL
pub async fn create_logical_plan(&self, sql: &str) -> Result<LogicalPlan> {
let statement = self.sql_to_statement(sql)?;
let dialect = self.config.options().sql_parser.dialect.as_str();
let statement = self.sql_to_statement(sql, dialect)?;
let plan = self.statement_to_plan(statement).await?;
Ok(plan)
}
Expand Down Expand Up @@ -1838,6 +1846,29 @@ impl From<&SessionState> for TaskContext {
}
}

// TODO: remove when https://github.com/sqlparser-rs/sqlparser-rs/pull/848 is released
fn create_dialect_from_str(dialect_name: &str) -> Result<Box<dyn Dialect>> {
match dialect_name.to_lowercase().as_str() {
"generic" => Ok(Box::new(GenericDialect)),
"mysql" => Ok(Box::new(MySqlDialect {})),
"postgresql" | "postgres" => Ok(Box::new(PostgreSqlDialect {})),
"hive" => Ok(Box::new(HiveDialect {})),
"sqlite" => Ok(Box::new(SQLiteDialect {})),
"snowflake" => Ok(Box::new(SnowflakeDialect)),
"redshift" => Ok(Box::new(RedshiftSqlDialect {})),
"mssql" => Ok(Box::new(MsSqlDialect {})),
"clickhouse" => Ok(Box::new(ClickHouseDialect {})),
"bigquery" => Ok(Box::new(BigQueryDialect)),
"ansi" => Ok(Box::new(AnsiDialect {})),
_ => {
Err(DataFusionError::Internal(format!(
"Unsupported SQL dialect: {}. Available dialects: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi.",
dialect_name
)))
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
13 changes: 13 additions & 0 deletions datafusion/core/src/physical_optimizer/aggregate_statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>>
{
if partial_agg_exec.mode() == &AggregateMode::Partial
&& partial_agg_exec.group_expr().is_empty()
&& partial_agg_exec.filter_expr().iter().all(|e| e.is_none())
{
let stats = partial_agg_exec.input().statistics();
if stats.is_exact {
Expand Down Expand Up @@ -410,6 +411,7 @@ mod tests {
AggregateMode::Partial,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
vec![None],
source,
Arc::clone(&schema),
)?;
Expand All @@ -418,6 +420,7 @@ mod tests {
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
vec![None],
Arc::new(partial_agg),
Arc::clone(&schema),
)?;
Expand All @@ -438,6 +441,7 @@ mod tests {
AggregateMode::Partial,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
vec![None],
source,
Arc::clone(&schema),
)?;
Expand All @@ -446,6 +450,7 @@ mod tests {
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
vec![None],
Arc::new(partial_agg),
Arc::clone(&schema),
)?;
Expand All @@ -465,6 +470,7 @@ mod tests {
AggregateMode::Partial,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
vec![None],
source,
Arc::clone(&schema),
)?;
Expand All @@ -476,6 +482,7 @@ mod tests {
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
vec![None],
Arc::new(coalesce),
Arc::clone(&schema),
)?;
Expand All @@ -495,6 +502,7 @@ mod tests {
AggregateMode::Partial,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
vec![None],
source,
Arc::clone(&schema),
)?;
Expand All @@ -506,6 +514,7 @@ mod tests {
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
vec![None],
Arc::new(coalesce),
Arc::clone(&schema),
)?;
Expand Down Expand Up @@ -536,6 +545,7 @@ mod tests {
AggregateMode::Partial,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
vec![None],
filter,
Arc::clone(&schema),
)?;
Expand All @@ -544,6 +554,7 @@ mod tests {
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
vec![None],
Arc::new(partial_agg),
Arc::clone(&schema),
)?;
Expand Down Expand Up @@ -579,6 +590,7 @@ mod tests {
AggregateMode::Partial,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
vec![None],
filter,
Arc::clone(&schema),
)?;
Expand All @@ -587,6 +599,7 @@ mod tests {
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![agg.count_expr()],
vec![None],
Arc::new(partial_agg),
Arc::clone(&schema),
)?;
Expand Down
8 changes: 8 additions & 0 deletions datafusion/core/src/physical_optimizer/dist_enforcement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ fn adjust_input_keys_ordering(
mode,
group_by,
aggr_expr,
filter_expr,
input,
input_schema,
..
Expand All @@ -264,6 +265,7 @@ fn adjust_input_keys_ordering(
&parent_required,
group_by,
aggr_expr,
filter_expr,
input.clone(),
input_schema,
)?),
Expand Down Expand Up @@ -369,6 +371,7 @@ fn reorder_aggregate_keys(
parent_required: &[Arc<dyn PhysicalExpr>],
group_by: &PhysicalGroupBy,
aggr_expr: &[Arc<dyn AggregateExpr>],
filter_expr: &[Option<Arc<dyn PhysicalExpr>>],
agg_input: Arc<dyn ExecutionPlan>,
input_schema: &SchemaRef,
) -> Result<PlanWithKeyRequirements> {
Expand Down Expand Up @@ -398,6 +401,7 @@ fn reorder_aggregate_keys(
mode,
group_by,
aggr_expr,
filter_expr,
input,
input_schema,
..
Expand All @@ -416,6 +420,7 @@ fn reorder_aggregate_keys(
AggregateMode::Partial,
new_partial_group_by,
aggr_expr.clone(),
filter_expr.clone(),
input.clone(),
input_schema.clone(),
)?))
Expand Down Expand Up @@ -446,6 +451,7 @@ fn reorder_aggregate_keys(
AggregateMode::FinalPartitioned,
new_group_by,
aggr_expr.to_vec(),
filter_expr.to_vec(),
partial_agg,
input_schema.clone(),
)?);
Expand Down Expand Up @@ -1067,11 +1073,13 @@ mod tests {
AggregateMode::FinalPartitioned,
final_grouping,
vec![],
vec![],
Arc::new(
AggregateExec::try_new(
AggregateMode::Partial,
group_by,
vec![],
vec![],
input,
schema.clone(),
)
Expand Down
2 changes: 2 additions & 0 deletions datafusion/core/src/physical_optimizer/repartition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,11 +477,13 @@ mod tests {
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![],
vec![],
Arc::new(
AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::default(),
vec![],
vec![],
input,
schema.clone(),
)
Expand Down
1 change: 1 addition & 0 deletions datafusion/core/src/physical_optimizer/sort_enforcement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2469,6 +2469,7 @@ mod tests {
AggregateMode::Final,
PhysicalGroupBy::default(),
vec![],
vec![],
input,
schema,
)
Expand Down
Loading

0 comments on commit dafe997

Please sign in to comment.