Skip to content

Commit

Permalink
allow DateTimeField::Custom with EXTRACT in Postgres (#1394)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Aug 26, 2024
1 parent 7282ce2 commit 222b7d1
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 6 deletions.
8 changes: 8 additions & 0 deletions src/dialect/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,12 @@ impl Dialect for GenericDialect {
fn support_map_literal_syntax(&self) -> bool {
true
}

fn allow_extract_custom(&self) -> bool {
true
}

fn allow_extract_single_quotes(&self) -> bool {
true
}
}
10 changes: 10 additions & 0 deletions src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,16 @@ pub trait Dialect: Debug + Any {
fn describe_requires_table_keyword(&self) -> bool {
false
}

/// Returns true if this dialect allows the `EXTRACT` function to words other than [`Keyword`].
fn allow_extract_custom(&self) -> bool {
false
}

/// Returns true if this dialect allows the `EXTRACT` function to use single quotes in the part being extracted.
fn allow_extract_single_quotes(&self) -> bool {
false
}
}

/// This represents the operators for which precedence must be defined
Expand Down
8 changes: 8 additions & 0 deletions src/dialect/postgresql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,14 @@ impl Dialect for PostgreSqlDialect {
Precedence::Or => OR_PREC,
}
}

fn allow_extract_custom(&self) -> bool {
true
}

fn allow_extract_single_quotes(&self) -> bool {
true
}
}

pub fn parse_comment(parser: &mut Parser) -> Result<Statement, ParserError> {
Expand Down
8 changes: 8 additions & 0 deletions src/dialect/snowflake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ impl Dialect for SnowflakeDialect {
fn describe_requires_table_keyword(&self) -> bool {
true
}

fn allow_extract_custom(&self) -> bool {
true
}

fn allow_extract_single_quotes(&self) -> bool {
true
}
}

/// Parse snowflake create table statement.
Expand Down
5 changes: 2 additions & 3 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1970,15 +1970,14 @@ impl<'a> Parser<'a> {
Keyword::TIMEZONE_HOUR => Ok(DateTimeField::TimezoneHour),
Keyword::TIMEZONE_MINUTE => Ok(DateTimeField::TimezoneMinute),
Keyword::TIMEZONE_REGION => Ok(DateTimeField::TimezoneRegion),
_ if dialect_of!(self is SnowflakeDialect | GenericDialect) => {
_ if self.dialect.allow_extract_custom() => {
self.prev_token();
let custom = self.parse_identifier(false)?;
Ok(DateTimeField::Custom(custom))
}
_ => self.expected("date/time field", next_token),
},
Token::SingleQuotedString(_) if dialect_of!(self is SnowflakeDialect | GenericDialect) =>
{
Token::SingleQuotedString(_) if self.dialect.allow_extract_single_quotes() => {
self.prev_token();
let custom = self.parse_identifier(false)?;
Ok(DateTimeField::Custom(custom))
Expand Down
78 changes: 75 additions & 3 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2474,7 +2474,7 @@ fn parse_extract() {
verified_stmt("SELECT EXTRACT(TIMEZONE_REGION FROM d)");
verified_stmt("SELECT EXTRACT(TIME FROM d)");

let dialects = all_dialects_except(|d| d.is::<SnowflakeDialect>() || d.is::<GenericDialect>());
let dialects = all_dialects_except(|d| d.allow_extract_custom());
let res = dialects.parse_sql_statements("SELECT EXTRACT(JIFFY FROM d)");
assert_eq!(
ParserError::ParserError("Expected: date/time field, found: JIFFY".to_string()),
Expand Down Expand Up @@ -2573,7 +2573,7 @@ fn parse_ceil_datetime() {
verified_stmt("SELECT CEIL(d TO SECOND) FROM df");
verified_stmt("SELECT CEIL(d TO MILLISECOND) FROM df");

let dialects = all_dialects_except(|d| d.is::<SnowflakeDialect>() || d.is::<GenericDialect>());
let dialects = all_dialects_except(|d| d.allow_extract_custom());
let res = dialects.parse_sql_statements("SELECT CEIL(d TO JIFFY) FROM df");
assert_eq!(
ParserError::ParserError("Expected: date/time field, found: JIFFY".to_string()),
Expand All @@ -2600,7 +2600,7 @@ fn parse_floor_datetime() {
verified_stmt("SELECT FLOOR(d TO SECOND) FROM df");
verified_stmt("SELECT FLOOR(d TO MILLISECOND) FROM df");

let dialects = all_dialects_except(|d| d.is::<SnowflakeDialect>() || d.is::<GenericDialect>());
let dialects = all_dialects_except(|d| d.allow_extract_custom());
let res = dialects.parse_sql_statements("SELECT FLOOR(d TO JIFFY) FROM df");
assert_eq!(
ParserError::ParserError("Expected: date/time field, found: JIFFY".to_string()),
Expand Down Expand Up @@ -10467,3 +10467,75 @@ fn test_group_by_nothing() {
);
}
}

#[test]
fn test_extract_seconds_ok() {
let dialects = all_dialects_where(|d| d.allow_extract_custom());
let stmt = dialects.verified_expr("EXTRACT(seconds FROM '2 seconds'::INTERVAL)");

assert_eq!(
stmt,
Expr::Extract {
field: DateTimeField::Custom(Ident {
value: "seconds".to_string(),
quote_style: None,
}),
syntax: ExtractSyntax::From,
expr: Box::new(Expr::Cast {
kind: CastKind::DoubleColon,
expr: Box::new(Expr::Value(Value::SingleQuotedString(
"2 seconds".to_string()
))),
data_type: DataType::Interval,
format: None,
}),
}
)
}

#[test]
fn test_extract_seconds_single_quote_ok() {
let dialects = all_dialects_where(|d| d.allow_extract_custom());
let stmt = dialects.verified_expr(r#"EXTRACT('seconds' FROM '2 seconds'::INTERVAL)"#);

assert_eq!(
stmt,
Expr::Extract {
field: DateTimeField::Custom(Ident {
value: "seconds".to_string(),
quote_style: Some('\''),
}),
syntax: ExtractSyntax::From,
expr: Box::new(Expr::Cast {
kind: CastKind::DoubleColon,
expr: Box::new(Expr::Value(Value::SingleQuotedString(
"2 seconds".to_string()
))),
data_type: DataType::Interval,
format: None,
}),
}
)
}

#[test]
fn test_extract_seconds_err() {
let sql = "SELECT EXTRACT(seconds FROM '2 seconds'::INTERVAL)";
let dialects = all_dialects_except(|d| d.allow_extract_custom());
let err = dialects.parse_sql_statements(sql).unwrap_err();
assert_eq!(
err.to_string(),
"sql parser error: Expected: date/time field, found: seconds"
);
}

#[test]
fn test_extract_seconds_single_quote_err() {
let sql = r#"SELECT EXTRACT('seconds' FROM '2 seconds'::INTERVAL)"#;
let dialects = all_dialects_except(|d| d.allow_extract_single_quotes());
let err = dialects.parse_sql_statements(sql).unwrap_err();
assert_eq!(
err.to_string(),
"sql parser error: Expected: date/time field, found: 'seconds'"
);
}

0 comments on commit 222b7d1

Please sign in to comment.