Skip to content

Commit

Permalink
Add dialect param to use double precision for float64 in Postgres (#1…
Browse files Browse the repository at this point in the history
…1495)

* Add dialect param to use double precision for float64 in Postgres

* return ast data type instead of bool

* Fix errors in merging

* fix
  • Loading branch information
Sevenannn authored Jul 19, 2024
1 parent ebe61ba commit 827d0e3
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 1 deletion.
28 changes: 28 additions & 0 deletions datafusion/sql/src/unparser/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,18 @@ pub trait Dialect: Send + Sync {
IntervalStyle::PostgresVerbose
}

// Does the dialect use DOUBLE PRECISION to represent Float64 rather than DOUBLE?
// E.g. Postgres uses DOUBLE PRECISION instead of DOUBLE
fn float64_ast_dtype(&self) -> sqlparser::ast::DataType {
sqlparser::ast::DataType::Double
}

// The SQL type to use for Arrow Utf8 unparsing
// Most dialects use VARCHAR, but some, like MySQL, require CHAR
fn utf8_cast_dtype(&self) -> ast::DataType {
ast::DataType::Varchar(None)
}

// The SQL type to use for Arrow LargeUtf8 unparsing
// Most dialects use TEXT, but some, like MySQL, require CHAR
fn large_utf8_cast_dtype(&self) -> ast::DataType {
Expand Down Expand Up @@ -98,6 +105,10 @@ impl Dialect for PostgreSqlDialect {
fn interval_style(&self) -> IntervalStyle {
IntervalStyle::PostgresVerbose
}

fn float64_ast_dtype(&self) -> sqlparser::ast::DataType {
sqlparser::ast::DataType::DoublePrecision
}
}

pub struct MySqlDialect {}
Expand Down Expand Up @@ -137,6 +148,7 @@ pub struct CustomDialect {
supports_nulls_first_in_sort: bool,
use_timestamp_for_date64: bool,
interval_style: IntervalStyle,
float64_ast_dtype: sqlparser::ast::DataType,
utf8_cast_dtype: ast::DataType,
large_utf8_cast_dtype: ast::DataType,
}
Expand All @@ -148,6 +160,7 @@ impl Default for CustomDialect {
supports_nulls_first_in_sort: true,
use_timestamp_for_date64: false,
interval_style: IntervalStyle::SQLStandard,
float64_ast_dtype: sqlparser::ast::DataType::Double,
utf8_cast_dtype: ast::DataType::Varchar(None),
large_utf8_cast_dtype: ast::DataType::Text,
}
Expand Down Expand Up @@ -182,6 +195,10 @@ impl Dialect for CustomDialect {
self.interval_style
}

fn float64_ast_dtype(&self) -> sqlparser::ast::DataType {
self.float64_ast_dtype.clone()
}

fn utf8_cast_dtype(&self) -> ast::DataType {
self.utf8_cast_dtype.clone()
}
Expand Down Expand Up @@ -210,6 +227,7 @@ pub struct CustomDialectBuilder {
supports_nulls_first_in_sort: bool,
use_timestamp_for_date64: bool,
interval_style: IntervalStyle,
float64_ast_dtype: sqlparser::ast::DataType,
utf8_cast_dtype: ast::DataType,
large_utf8_cast_dtype: ast::DataType,
}
Expand All @@ -227,6 +245,7 @@ impl CustomDialectBuilder {
supports_nulls_first_in_sort: true,
use_timestamp_for_date64: false,
interval_style: IntervalStyle::PostgresVerbose,
float64_ast_dtype: sqlparser::ast::DataType::Double,
utf8_cast_dtype: ast::DataType::Varchar(None),
large_utf8_cast_dtype: ast::DataType::Text,
}
Expand All @@ -238,6 +257,7 @@ impl CustomDialectBuilder {
supports_nulls_first_in_sort: self.supports_nulls_first_in_sort,
use_timestamp_for_date64: self.use_timestamp_for_date64,
interval_style: self.interval_style,
float64_ast_dtype: self.float64_ast_dtype,
utf8_cast_dtype: self.utf8_cast_dtype,
large_utf8_cast_dtype: self.large_utf8_cast_dtype,
}
Expand Down Expand Up @@ -273,6 +293,14 @@ impl CustomDialectBuilder {
self
}

pub fn with_float64_ast_dtype(
mut self,
float64_ast_dtype: sqlparser::ast::DataType,
) -> Self {
self.float64_ast_dtype = float64_ast_dtype;
self
}

pub fn with_utf8_cast_dtype(mut self, utf8_cast_dtype: ast::DataType) -> Self {
self.utf8_cast_dtype = utf8_cast_dtype;
self
Expand Down
30 changes: 29 additions & 1 deletion datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1240,7 +1240,7 @@ impl Unparser<'_> {
not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
}
DataType::Float32 => Ok(ast::DataType::Float(None)),
DataType::Float64 => Ok(ast::DataType::Double),
DataType::Float64 => Ok(self.dialect.float64_ast_dtype()),
DataType::Timestamp(_, tz) => {
let tz_info = match tz {
Some(_) => TimezoneInfo::WithTimeZone,
Expand Down Expand Up @@ -1822,6 +1822,34 @@ mod tests {
Ok(())
}

#[test]
fn custom_dialect_float64_ast_dtype() -> Result<()> {
for (float64_ast_dtype, identifier) in [
(sqlparser::ast::DataType::Double, "DOUBLE"),
(
sqlparser::ast::DataType::DoublePrecision,
"DOUBLE PRECISION",
),
] {
let dialect = CustomDialectBuilder::new()
.with_float64_ast_dtype(float64_ast_dtype)
.build();
let unparser = Unparser::new(&dialect);

let expr = Expr::Cast(Cast {
expr: Box::new(col("a")),
data_type: DataType::Float64,
});
let ast = unparser.expr_to_sql(&expr)?;

let actual = format!("{}", ast);

let expected = format!(r#"CAST(a AS {identifier})"#);
assert_eq!(actual, expected);
}
Ok(())
}

#[test]
fn customer_dialect_support_nulls_first_in_ort() -> Result<()> {
let tests: Vec<(Expr, &str, bool)> = vec![
Expand Down

0 comments on commit 827d0e3

Please sign in to comment.