Skip to content

Commit

Permalink
Add support for PostgreSQL UNLISTEN syntax and Add support for Post…
Browse files Browse the repository at this point in the history
…gres `LOAD extension` expr (#1531)

Co-authored-by: Ifeanyi Ubah <ify1992@yahoo.com>
  • Loading branch information
wugeer and iffyio authored Nov 19, 2024
1 parent 92be237 commit 73947a5
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 33 deletions.
11 changes: 11 additions & 0 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3340,6 +3340,13 @@ pub enum Statement {
/// See Postgres <https://www.postgresql.org/docs/current/sql-listen.html>
LISTEN { channel: Ident },
/// ```sql
/// UNLISTEN
/// ```
/// stop listening for a notification
///
/// See Postgres <https://www.postgresql.org/docs/current/sql-unlisten.html>
UNLISTEN { channel: Ident },
/// ```sql
/// NOTIFY channel [ , payload ]
/// ```
/// send a notification event together with an optional “payload” string to channel
Expand Down Expand Up @@ -4948,6 +4955,10 @@ impl fmt::Display for Statement {
write!(f, "LISTEN {channel}")?;
Ok(())
}
Statement::UNLISTEN { channel } => {
write!(f, "UNLISTEN {channel}")?;
Ok(())
}
Statement::NOTIFY { channel, payload } => {
write!(f, "NOTIFY {channel}")?;
if let Some(payload) = payload {
Expand Down
9 changes: 2 additions & 7 deletions src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,13 +633,8 @@ pub trait Dialect: Debug + Any {
false
}

/// Returns true if the dialect supports the `LISTEN` statement
fn supports_listen(&self) -> bool {
false
}

/// Returns true if the dialect supports the `NOTIFY` statement
fn supports_notify(&self) -> bool {
/// Returns true if the dialect supports the `LISTEN`, `UNLISTEN` and `NOTIFY` statements
fn supports_listen_notify(&self) -> bool {
false
}

Expand Down
12 changes: 7 additions & 5 deletions src/dialect/postgresql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,9 @@ impl Dialect for PostgreSqlDialect {
}

/// see <https://www.postgresql.org/docs/current/sql-listen.html>
fn supports_listen(&self) -> bool {
true
}

/// see <https://www.postgresql.org/docs/current/sql-unlisten.html>
/// see <https://www.postgresql.org/docs/current/sql-notify.html>
fn supports_notify(&self) -> bool {
fn supports_listen_notify(&self) -> bool {
true
}

Expand All @@ -209,6 +206,11 @@ impl Dialect for PostgreSqlDialect {
fn supports_comment_on(&self) -> bool {
true
}

/// See <https://www.postgresql.org/docs/current/sql-load.html>
fn supports_load_extension(&self) -> bool {
true
}
}

pub fn parse_create(parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
Expand Down
1 change: 1 addition & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,7 @@ define_keywords!(
UNION,
UNIQUE,
UNKNOWN,
UNLISTEN,
UNLOAD,
UNLOCK,
UNLOGGED,
Expand Down
22 changes: 19 additions & 3 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,10 +532,11 @@ impl<'a> Parser<'a> {
Keyword::EXECUTE | Keyword::EXEC => self.parse_execute(),
Keyword::PREPARE => self.parse_prepare(),
Keyword::MERGE => self.parse_merge(),
// `LISTEN` and `NOTIFY` are Postgres-specific
// `LISTEN`, `UNLISTEN` and `NOTIFY` are Postgres-specific
// syntaxes. They are used for Postgres statement.
Keyword::LISTEN if self.dialect.supports_listen() => self.parse_listen(),
Keyword::NOTIFY if self.dialect.supports_notify() => self.parse_notify(),
Keyword::LISTEN if self.dialect.supports_listen_notify() => self.parse_listen(),
Keyword::UNLISTEN if self.dialect.supports_listen_notify() => self.parse_unlisten(),
Keyword::NOTIFY if self.dialect.supports_listen_notify() => self.parse_notify(),
// `PRAGMA` is sqlite specific https://www.sqlite.org/pragma.html
Keyword::PRAGMA => self.parse_pragma(),
Keyword::UNLOAD => self.parse_unload(),
Expand Down Expand Up @@ -999,6 +1000,21 @@ impl<'a> Parser<'a> {
Ok(Statement::LISTEN { channel })
}

pub fn parse_unlisten(&mut self) -> Result<Statement, ParserError> {
let channel = if self.consume_token(&Token::Mul) {
Ident::new(Expr::Wildcard.to_string())
} else {
match self.parse_identifier(false) {
Ok(expr) => expr,
_ => {
self.prev_token();
return self.expected("wildcard or identifier", self.peek_token());
}
}
};
Ok(Statement::UNLISTEN { channel })
}

pub fn parse_notify(&mut self) -> Result<Statement, ParserError> {
let channel = self.parse_identifier(false)?;
let payload = if self.consume_token(&Token::Comma) {
Expand Down
77 changes: 73 additions & 4 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11595,7 +11595,7 @@ fn test_show_dbs_schemas_tables_views() {

#[test]
fn parse_listen_channel() {
let dialects = all_dialects_where(|d| d.supports_listen());
let dialects = all_dialects_where(|d| d.supports_listen_notify());

match dialects.verified_stmt("LISTEN test1") {
Statement::LISTEN { channel } => {
Expand All @@ -11609,17 +11609,48 @@ fn parse_listen_channel() {
ParserError::ParserError("Expected: identifier, found: *".to_string())
);

let dialects = all_dialects_where(|d| !d.supports_listen());
let dialects = all_dialects_where(|d| !d.supports_listen_notify());

assert_eq!(
dialects.parse_sql_statements("LISTEN test1").unwrap_err(),
ParserError::ParserError("Expected: an SQL statement, found: LISTEN".to_string())
);
}

#[test]
fn parse_unlisten_channel() {
let dialects = all_dialects_where(|d| d.supports_listen_notify());

match dialects.verified_stmt("UNLISTEN test1") {
Statement::UNLISTEN { channel } => {
assert_eq!(Ident::new("test1"), channel);
}
_ => unreachable!(),
};

match dialects.verified_stmt("UNLISTEN *") {
Statement::UNLISTEN { channel } => {
assert_eq!(Ident::new("*"), channel);
}
_ => unreachable!(),
};

assert_eq!(
dialects.parse_sql_statements("UNLISTEN +").unwrap_err(),
ParserError::ParserError("Expected: wildcard or identifier, found: +".to_string())
);

let dialects = all_dialects_where(|d| !d.supports_listen_notify());

assert_eq!(
dialects.parse_sql_statements("UNLISTEN test1").unwrap_err(),
ParserError::ParserError("Expected: an SQL statement, found: UNLISTEN".to_string())
);
}

#[test]
fn parse_notify_channel() {
let dialects = all_dialects_where(|d| d.supports_notify());
let dialects = all_dialects_where(|d| d.supports_listen_notify());

match dialects.verified_stmt("NOTIFY test1") {
Statement::NOTIFY { channel, payload } => {
Expand Down Expand Up @@ -11655,7 +11686,7 @@ fn parse_notify_channel() {
"NOTIFY test1",
"NOTIFY test1, 'this is a test notification'",
];
let dialects = all_dialects_where(|d| !d.supports_notify());
let dialects = all_dialects_where(|d| !d.supports_listen_notify());

for &sql in &sql_statements {
assert_eq!(
Expand Down Expand Up @@ -11864,6 +11895,44 @@ fn parse_load_data() {
);
}

#[test]
fn test_load_extension() {
let dialects = all_dialects_where(|d| d.supports_load_extension());
let not_supports_load_extension_dialects = all_dialects_where(|d| !d.supports_load_extension());
let sql = "LOAD my_extension";

match dialects.verified_stmt(sql) {
Statement::Load { extension_name } => {
assert_eq!(Ident::new("my_extension"), extension_name);
}
_ => unreachable!(),
};

assert_eq!(
not_supports_load_extension_dialects
.parse_sql_statements(sql)
.unwrap_err(),
ParserError::ParserError(
"Expected: `DATA` or an extension name after `LOAD`, found: my_extension".to_string()
)
);

let sql = "LOAD 'filename'";

match dialects.verified_stmt(sql) {
Statement::Load { extension_name } => {
assert_eq!(
Ident {
value: "filename".to_string(),
quote_style: Some('\'')
},
extension_name
);
}
_ => unreachable!(),
};
}

#[test]
fn test_select_top() {
let dialects = all_dialects_where(|d| d.supports_top_before_distinct());
Expand Down
14 changes: 0 additions & 14 deletions tests/sqlparser_duckdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,20 +359,6 @@ fn test_duckdb_install() {
);
}

#[test]
fn test_duckdb_load_extension() {
let stmt = duckdb().verified_stmt("LOAD my_extension");
assert_eq!(
Statement::Load {
extension_name: Ident {
value: "my_extension".to_string(),
quote_style: None
}
},
stmt
);
}

#[test]
fn test_duckdb_struct_literal() {
//struct literal syntax https://duckdb.org/docs/sql/data_types/struct#creating-structs
Expand Down

0 comments on commit 73947a5

Please sign in to comment.