Skip to content

Commit

Permalink
feat: support different USE statement syntaxes
Browse files Browse the repository at this point in the history
  • Loading branch information
kacpermuda committed Aug 16, 2024
1 parent 11a6e6f commit e7690a7
Show file tree
Hide file tree
Showing 10 changed files with 383 additions and 9 deletions.
32 changes: 26 additions & 6 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2515,11 +2515,13 @@ pub enum Statement {
/// Note: this is a MySQL-specific statement.
ShowCollation { filter: Option<ShowStatementFilter> },
/// ```sql
/// USE
/// USE [DATABASE|SCHEMA|CATALOG|...] [<db_name>.<schema_name>|<db_name>|<schema_name>]
/// ```
///
/// Note: This is a MySQL-specific statement.
Use { db_name: Ident },
Use {
db_name: Option<Ident>,
schema_name: Option<Ident>,
keyword: Option<String>,
},
/// ```sql
/// START [ TRANSACTION | WORK ] | START TRANSACTION } ...
/// ```
Expand Down Expand Up @@ -4125,8 +4127,26 @@ impl fmt::Display for Statement {
}
Ok(())
}
Statement::Use { db_name } => {
write!(f, "USE {db_name}")?;
Statement::Use {
db_name,
schema_name,
keyword,
} => {
write!(f, "USE")?;

if let Some(kw) = keyword.as_ref() {
write!(f, " {}", kw)?;
}

if let Some(db_name) = db_name {
write!(f, " {}", db_name)?;
if let Some(schema_name) = schema_name {
write!(f, ".{}", schema_name)?;
}
} else if let Some(schema_name) = schema_name {
write!(f, " {}", schema_name)?;
}

Ok(())
}
Statement::ShowCollation { filter } => {
Expand Down
1 change: 1 addition & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ define_keywords!(
CASCADED,
CASE,
CAST,
CATALOG,
CEIL,
CEILING,
CENTURY,
Expand Down
59 changes: 57 additions & 2 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9225,8 +9225,63 @@ impl<'a> Parser<'a> {
}

pub fn parse_use(&mut self) -> Result<Statement, ParserError> {
let db_name = self.parse_identifier(false)?;
Ok(Statement::Use { db_name })
// What should be treated as keyword in given dialect
let allowed_keywords = if dialect_of!(self is HiveDialect) {
vec![Keyword::DEFAULT]
} else if dialect_of!(self is DatabricksDialect) {
vec![Keyword::CATALOG, Keyword::DATABASE, Keyword::SCHEMA]
} else if dialect_of!(self is SnowflakeDialect) {
vec![Keyword::DATABASE, Keyword::SCHEMA]
} else {
vec![]
};
let parsed_keyword = self.parse_one_of_keywords(&allowed_keywords);

// Hive dialect accepts USE DEFAULT; statement without any db specified
if dialect_of!(self is HiveDialect) && parsed_keyword == Some(Keyword::DEFAULT) {
return Ok(Statement::Use {
db_name: None,
schema_name: None,
keyword: Some("DEFAULT".to_string()),
});
}

// Parse the object name, which might be a single identifier or fully qualified name (e.g., x.y)
let parts = self.parse_object_name(false)?.0;
let (db_name, schema_name) = match parts.len() {
1 => {
// Single identifier found
if dialect_of!(self is DatabricksDialect) {
if parsed_keyword == Some(Keyword::CATALOG) {
// Databricks: CATALOG keyword provided, treat as database name
(Some(parts[0].clone()), None)
} else {
// Databricks: DATABASE, SCHEMA or no keyword provided, treat as schema name
(None, Some(parts[0].clone()))
}
} else if dialect_of!(self is SnowflakeDialect)
&& parsed_keyword == Some(Keyword::SCHEMA)
{
// Snowflake: SCHEMA keyword provided, treat as schema name
(None, Some(parts[0].clone()))
} else {
// Other dialects: treat as database name by default
(Some(parts[0].clone()), None)
}
}
2 => (Some(parts[0].clone()), Some(parts[1].clone())),
_ => {
return Err(ParserError::ParserError(
"Invalid format in the USE statement".to_string(),
))
}
};

Ok(Statement::Use {
db_name,
schema_name,
keyword: parsed_keyword.map(|kw| format!("{:?}", kw)),
})
}

pub fn parse_table_and_joins(&mut self) -> Result<TableWithJoins, ParserError> {
Expand Down
36 changes: 36 additions & 0 deletions tests/sqlparser_clickhouse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,42 @@ fn test_prewhere() {
}
}

#[test]
fn parse_use() {
    // Every word after USE, including DATABASE / SCHEMA / CATALOG, is expected
    // to round-trip and come back as a plain database name with no keyword.
    for name in ["mydb", "DATABASE", "SCHEMA", "CATALOG"] {
        let sql = format!("USE {name}");
        let expected = Statement::Use {
            db_name: Some(Ident::new(name)),
            schema_name: None,
            keyword: None,
        };
        assert_eq!(clickhouse().verified_stmt(&sql), expected);
    }
}

#[test]
fn test_query_with_format_clause() {
let format_options = vec!["TabSeparated", "JSONCompact", "NULL"];
Expand Down
44 changes: 44 additions & 0 deletions tests/sqlparser_databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,47 @@ fn test_values_clause() {
// TODO: support this example from https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-values.html#examples
// databricks().verified_query("VALUES 1, 2, 3");
}

#[test]
fn parse_use() {
    // (sql, expected db_name, expected schema_name, expected keyword)
    let valid_cases: [(&str, Option<&str>, Option<&str>, Option<&str>); 4] = [
        // A bare name is a schema.
        ("USE my_schema", None, Some("my_schema"), None),
        // Only CATALOG turns the name into a database (catalog) name.
        (
            "USE CATALOG my_catalog",
            Some("my_catalog"),
            None,
            Some("CATALOG"),
        ),
        ("USE DATABASE my_schema", None, Some("my_schema"), Some("DATABASE")),
        ("USE SCHEMA my_schema", None, Some("my_schema"), Some("SCHEMA")),
    ];
    for (sql, db_name, schema_name, keyword) in valid_cases {
        assert_eq!(
            databricks().verified_stmt(sql),
            Statement::Use {
                db_name: db_name.map(Ident::new),
                schema_name: schema_name.map(Ident::new),
                keyword: keyword.map(str::to_string),
            }
        );
    }

    // A keyword with no name after it must be rejected.
    for sql in ["USE SCHEMA", "USE DATABASE", "USE CATALOG"] {
        let err = databricks().parse_sql_statements(sql).unwrap_err();
        assert_eq!(
            err,
            ParserError::ParserError("Expected: identifier, found: EOF".to_string()),
        );
    }
}
52 changes: 52 additions & 0 deletions tests/sqlparser_duckdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -756,3 +756,55 @@ fn test_duckdb_union_datatype() {
stmt
);
}

#[test]
fn parse_use() {
    // (sql, expected db_name, expected schema_name). No keyword is expected
    // in any case: DATABASE / SCHEMA / CATALOG are treated as ordinary names.
    // All cases use the same `assert_eq!` macro, replacing the previous mix
    // of `std::assert_eq!` and `assert_eq!` within this one test.
    let cases: [(&str, &str, Option<&str>); 6] = [
        ("USE mydb", "mydb", None),
        ("USE mydb.my_schema", "mydb", Some("my_schema")),
        ("USE DATABASE", "DATABASE", None),
        ("USE SCHEMA", "SCHEMA", None),
        ("USE CATALOG", "CATALOG", None),
        ("USE CATALOG.SCHEMA", "CATALOG", Some("SCHEMA")),
    ];
    for (sql, db_name, schema_name) in cases {
        assert_eq!(
            duckdb().verified_stmt(sql),
            Statement::Use {
                db_name: Some(Ident::new(db_name)),
                schema_name: schema_name.map(Ident::new),
                keyword: None,
            }
        );
    }
}
44 changes: 44 additions & 0 deletions tests/sqlparser_hive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,50 @@ fn parse_delimited_identifiers() {
//TODO verified_stmt(r#"UPDATE foo SET "bar" = 5"#);
}

#[test]
fn parse_use() {
    // Expected AST for `USE <name>` where <name> is a plain database name.
    let plain_db = |name: &str| Statement::Use {
        db_name: Some(Ident::new(name)),
        schema_name: None,
        keyword: None,
    };

    assert_eq!(hive().verified_stmt("USE mydb"), plain_db("mydb"));

    // DEFAULT is the one word reported as a keyword, not as a db_name.
    let use_default = Statement::Use {
        db_name: None,
        schema_name: None,
        keyword: Some("DEFAULT".to_string()),
    };
    assert_eq!(hive().verified_stmt("USE DEFAULT"), use_default);

    // The remaining keyword-like words are still ordinary database names.
    assert_eq!(hive().verified_stmt("USE DATABASE"), plain_db("DATABASE"));
    assert_eq!(hive().verified_stmt("USE SCHEMA"), plain_db("SCHEMA"));
    assert_eq!(hive().verified_stmt("USE CATALOG"), plain_db("CATALOG"));
}

fn hive() -> TestedDialects {
TestedDialects {
dialects: vec![Box::new(HiveDialect {})],
Expand Down
36 changes: 36 additions & 0 deletions tests/sqlparser_mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,42 @@ fn parse_mssql_declare() {
);
}

#[test]
fn parse_use() {
    // Builds the expected AST: the name lands in db_name, nothing else is set.
    let db_only = |name: &str| Statement::Use {
        db_name: Some(Ident::new(name)),
        schema_name: None,
        keyword: None,
    };

    assert_eq!(ms().verified_stmt("USE mydb"), db_only("mydb"));
    // DATABASE / SCHEMA / CATALOG are expected to parse as ordinary names.
    assert_eq!(ms().verified_stmt("USE DATABASE"), db_only("DATABASE"));
    assert_eq!(ms().verified_stmt("USE SCHEMA"), db_only("SCHEMA"));
    assert_eq!(ms().verified_stmt("USE CATALOG"), db_only("CATALOG"));
}

fn ms() -> TestedDialects {
TestedDialects {
dialects: vec![Box::new(MsSqlDialect {})],
Expand Down
28 changes: 27 additions & 1 deletion tests/sqlparser_mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,33 @@ fn parse_use() {
assert_eq!(
mysql_and_generic().verified_stmt("USE mydb"),
Statement::Use {
db_name: Ident::new("mydb")
db_name: Some(Ident::new("mydb")),
schema_name: None,
keyword: None
}
);
assert_eq!(
mysql_and_generic().verified_stmt("USE DATABASE"),
Statement::Use {
db_name: Some(Ident::new("DATABASE")),
schema_name: None,
keyword: None
}
);
assert_eq!(
mysql_and_generic().verified_stmt("USE SCHEMA"),
Statement::Use {
db_name: Some(Ident::new("SCHEMA")),
schema_name: None,
keyword: None
}
);
assert_eq!(
mysql_and_generic().verified_stmt("USE CATALOG"),
Statement::Use {
db_name: Some(Ident::new("CATALOG")),
schema_name: None,
keyword: None
}
);
}
Expand Down
Loading

0 comments on commit e7690a7

Please sign in to comment.