diff --git a/src/dialect/generic.rs b/src/dialect/generic.rs index 2777dfb02..211abd976 100644 --- a/src/dialect/generic.rs +++ b/src/dialect/generic.rs @@ -78,4 +78,12 @@ impl Dialect for GenericDialect { fn support_map_literal_syntax(&self) -> bool { true } + + fn allow_extract_custom(&self) -> bool { + true + } + + fn allow_extract_single_quotes(&self) -> bool { + true + } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index df3022adc..143c8e1c9 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -499,6 +499,16 @@ pub trait Dialect: Debug + Any { fn describe_requires_table_keyword(&self) -> bool { false } + + /// Returns true if this dialect allows the `EXTRACT` function to words other than [`Keyword`]. + fn allow_extract_custom(&self) -> bool { + false + } + + /// Returns true if this dialect allows the `EXTRACT` function to use single quotes in the part being extracted. + fn allow_extract_single_quotes(&self) -> bool { + false + } } /// This represents the operators for which precedence must be defined diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index c25a80f67..8abaa4a5f 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -154,6 +154,14 @@ impl Dialect for PostgreSqlDialect { Precedence::Or => OR_PREC, } } + + fn allow_extract_custom(&self) -> bool { + true + } + + fn allow_extract_single_quotes(&self) -> bool { + true + } } pub fn parse_comment(parser: &mut Parser) -> Result { diff --git a/src/dialect/snowflake.rs b/src/dialect/snowflake.rs index 89508c3bb..4f37004b1 100644 --- a/src/dialect/snowflake.rs +++ b/src/dialect/snowflake.rs @@ -158,6 +158,14 @@ impl Dialect for SnowflakeDialect { fn describe_requires_table_keyword(&self) -> bool { true } + + fn allow_extract_custom(&self) -> bool { + true + } + + fn allow_extract_single_quotes(&self) -> bool { + true + } } /// Parse snowflake create table statement. diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 8f8c3f050..302e5e660 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1970,15 +1970,14 @@ impl<'a> Parser<'a> { Keyword::TIMEZONE_HOUR => Ok(DateTimeField::TimezoneHour), Keyword::TIMEZONE_MINUTE => Ok(DateTimeField::TimezoneMinute), Keyword::TIMEZONE_REGION => Ok(DateTimeField::TimezoneRegion), - _ if dialect_of!(self is SnowflakeDialect | GenericDialect) => { + _ if self.dialect.allow_extract_custom() => { self.prev_token(); let custom = self.parse_identifier(false)?; Ok(DateTimeField::Custom(custom)) } _ => self.expected("date/time field", next_token), }, - Token::SingleQuotedString(_) if dialect_of!(self is SnowflakeDialect | GenericDialect) => - { + Token::SingleQuotedString(_) if self.dialect.allow_extract_single_quotes() => { self.prev_token(); let custom = self.parse_identifier(false)?; Ok(DateTimeField::Custom(custom)) diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 517e978b4..dd4aad146 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -2474,7 +2474,7 @@ fn parse_extract() { verified_stmt("SELECT EXTRACT(TIMEZONE_REGION FROM d)"); verified_stmt("SELECT EXTRACT(TIME FROM d)"); - let dialects = all_dialects_except(|d| d.is::() || d.is::()); + let dialects = all_dialects_except(|d| d.allow_extract_custom()); let res = dialects.parse_sql_statements("SELECT EXTRACT(JIFFY FROM d)"); assert_eq!( ParserError::ParserError("Expected: date/time field, found: JIFFY".to_string()), @@ -2573,7 +2573,7 @@ fn parse_ceil_datetime() { verified_stmt("SELECT CEIL(d TO SECOND) FROM df"); verified_stmt("SELECT CEIL(d TO MILLISECOND) FROM df"); - let dialects = all_dialects_except(|d| d.is::() || d.is::()); + let dialects = all_dialects_except(|d| d.allow_extract_custom()); let res = dialects.parse_sql_statements("SELECT CEIL(d TO JIFFY) FROM df"); assert_eq!( ParserError::ParserError("Expected: date/time field, found: JIFFY".to_string()), @@ -2600,7 +2600,7 @@ fn parse_floor_datetime() { verified_stmt("SELECT FLOOR(d TO SECOND) FROM df"); verified_stmt("SELECT FLOOR(d TO MILLISECOND) FROM df"); - let dialects = all_dialects_except(|d| d.is::() || d.is::()); + let dialects = all_dialects_except(|d| d.allow_extract_custom()); let res = dialects.parse_sql_statements("SELECT FLOOR(d TO JIFFY) FROM df"); assert_eq!( ParserError::ParserError("Expected: date/time field, found: JIFFY".to_string()), @@ -10467,3 +10467,75 @@ fn test_group_by_nothing() { ); } } + +#[test] +fn test_extract_seconds_ok() { + let dialects = all_dialects_where(|d| d.allow_extract_custom()); + let stmt = dialects.verified_expr("EXTRACT(seconds FROM '2 seconds'::INTERVAL)"); + + assert_eq!( + stmt, + Expr::Extract { + field: DateTimeField::Custom(Ident { + value: "seconds".to_string(), + quote_style: None, + }), + syntax: ExtractSyntax::From, + expr: Box::new(Expr::Cast { + kind: CastKind::DoubleColon, + expr: Box::new(Expr::Value(Value::SingleQuotedString( + "2 seconds".to_string() + ))), + data_type: DataType::Interval, + format: None, + }), + } + ) +} + +#[test] +fn test_extract_seconds_single_quote_ok() { + let dialects = all_dialects_where(|d| d.allow_extract_custom()); + let stmt = dialects.verified_expr(r#"EXTRACT('seconds' FROM '2 seconds'::INTERVAL)"#); + + assert_eq!( + stmt, + Expr::Extract { + field: DateTimeField::Custom(Ident { + value: "seconds".to_string(), + quote_style: Some('\''), + }), + syntax: ExtractSyntax::From, + expr: Box::new(Expr::Cast { + kind: CastKind::DoubleColon, + expr: Box::new(Expr::Value(Value::SingleQuotedString( + "2 seconds".to_string() + ))), + data_type: DataType::Interval, + format: None, + }), + } + ) +} + +#[test] +fn test_extract_seconds_err() { + let sql = "SELECT EXTRACT(seconds FROM '2 seconds'::INTERVAL)"; + let dialects = all_dialects_except(|d| d.allow_extract_custom()); + let err = dialects.parse_sql_statements(sql).unwrap_err(); + assert_eq!( + err.to_string(), + "sql parser error: Expected: date/time field, found: seconds" + ); +} + +#[test] +fn test_extract_seconds_single_quote_err() { + let sql = r#"SELECT EXTRACT('seconds' FROM '2 seconds'::INTERVAL)"#; + let dialects = all_dialects_except(|d| d.allow_extract_single_quotes()); + let err = dialects.parse_sql_statements(sql).unwrap_err(); + assert_eq!( + err.to_string(), + "sql parser error: Expected: date/time field, found: 'seconds'" + ); +}