Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: sqlparse fallback for formatting queries #30578

Merged
merged 2 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 88 additions & 20 deletions superset/sql/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from typing import Any, Generic, TypeVar

import sqlglot
import sqlparse
from deprecation import deprecated
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError
Expand Down Expand Up @@ -138,24 +140,22 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
"""
Base class for SQL statements.

The class can be instantiated with a string representation of the script or, for
efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
which will split a script in multiple already parsed statements.
The class should be instantiated with a string representation of the script and, for
efficiency reasons, optionally with a pre-parsed AST. This is useful with
`sqlglot.parse`, which will split a script into multiple already-parsed statements.

The `engine` parameter comes from the `engine` attribute in a Superset DB engine
spec.
"""

def __init__(
self,
statement: str | InternalRepresentation,
statement: str,
engine: str,
ast: InternalRepresentation | None = None,
):
self._parsed: InternalRepresentation = (
self._parse_statement(statement, engine)
if isinstance(statement, str)
else statement
)
self._sql = statement
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before this PR the Statement classes didn't store a reference to the original SQL, only the AST. So I had to modify it here.

self._parsed = ast or self._parse_statement(statement, engine)
self.engine = engine
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)

Expand Down Expand Up @@ -239,11 +239,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):

def __init__(
self,
statement: str | exp.Expression,
statement: str,
engine: str,
ast: exp.Expression | None = None,
):
self._dialect = SQLGLOT_DIALECTS.get(engine)
super().__init__(statement, engine)
super().__init__(statement, engine, ast)

@classmethod
def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
Expand Down Expand Up @@ -275,11 +276,47 @@ def split_script(
script: str,
engine: str,
) -> list[SQLStatement]:
return [
cls(statement, engine)
for statement in cls._parse(script, engine)
if statement
]
if engine in SQLGLOT_DIALECTS:
try:
return [
cls(ast.sql(), engine, ast)
for ast in cls._parse(script, engine)
if ast
]
except ValueError:
# `ast.sql()` might raise an error in some cases (e.g., `SHOW TABLES
# FROM`). In this case, we rely on the tokenizer to generate the
# statements.
pass

# When we don't have a sqlglot dialect we can't rely on `ast.sql()` to correctly
# generate the SQL of each statement, so we tokenize the script and split it
# based on the location of semi-colons.
statements = []
start = 0
remainder = script

try:
tokens = sqlglot.tokenize(script)
except sqlglot.errors.TokenError as ex:
raise SupersetParseError(
script,
engine,
message="Unable to tokenize script",
) from ex

for token in tokens:
if token.token_type == sqlglot.TokenType.SEMICOLON:
statement, start = script[start : token.start], token.end + 1
ast = cls._parse(statement, engine)[0]
statements.append(cls(statement.strip(), engine, ast))
remainder = script[start:]

if remainder.strip():
ast = cls._parse(remainder, engine)[0]
statements.append(cls(remainder.strip(), engine, ast))

return statements

@classmethod
def _parse_statement(
Expand Down Expand Up @@ -349,8 +386,34 @@ def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL statement.
"""
write = Dialect.get_or_raise(self._dialect)
return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
if self._dialect:
try:
write = Dialect.get_or_raise(self._dialect)
return write.generate(
self._parsed,
copy=False,
comments=comments,
pretty=True,
)
except ValueError:
pass

return self._fallback_formatting()

@deprecated(deprecated_in="4.0", removed_in="5.0")
def _fallback_formatting(self) -> str:
"""
Format SQL without a specific dialect.

Reformatting SQL using the generic sqlglot dialect is known to break queries.
For example, it will change `foo NOT IN (1, 2)` to `NOT foo IN (1,2)`, which
breaks the query for Firebolt. To avoid this, we use sqlparse for formatting
when the dialect is not known.

In 5.0 we should remove `sqlparse`, and the method should return the query
unmodified.
"""
return sqlparse.format(self._sql, reindent=True, keyword_case="upper")

def get_settings(self) -> dict[str, str | bool]:
"""
Expand Down Expand Up @@ -456,7 +519,9 @@ def split_script(
https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
for more information.
"""
return [cls(statement, engine) for statement in split_kql(script)]
return [
cls(statement, engine, statement.strip()) for statement in split_kql(script)
]

@classmethod
def _parse_statement(
Expand Down Expand Up @@ -498,7 +563,7 @@ def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL statement.
"""
return self._parsed
return self._sql.strip()

def get_settings(self) -> dict[str, str | bool]:
"""
Expand Down Expand Up @@ -548,6 +613,9 @@ def __init__(
def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL script.

Note that even though KQL is very different from SQL, multiple statements are
still separated by semi-colons.
"""
return ";\n".join(statement.format(comments) for statement in self.statements)

Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/sql_lab/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def test_format_sql_request(self):
"/api/v1/sqllab/format_sql/",
json=data,
)
success_resp = {"result": "SELECT\n 1\nFROM my_table"}
success_resp = {"result": "SELECT 1\nFROM my_table"}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, success_resp) # noqa: PT009
assert rv.status_code == 200
Expand Down
16 changes: 2 additions & 14 deletions tests/unit_tests/db_engine_specs/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,14 +241,7 @@ class NoLimitDBEngineSpec(BaseEngineSpec):
latest_partition=False,
cols=cols,
)
assert (
sql
== """SELECT
a
FROM my_table
LIMIT ?
OFFSET ?"""
)
assert sql == "SELECT a\nFROM my_table\nLIMIT ?\nOFFSET ?"

sql = NoLimitDBEngineSpec.select_star(
database=database,
Expand All @@ -260,12 +253,7 @@ class NoLimitDBEngineSpec(BaseEngineSpec):
latest_partition=False,
cols=cols,
)
assert (
sql
== """SELECT
a
FROM my_table"""
)
assert sql == "SELECT a\nFROM my_table"


def test_extra_table_metadata(mocker: MockerFixture) -> None:
Expand Down
34 changes: 34 additions & 0 deletions tests/unit_tests/sql/parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,40 @@ def test_extract_tables_show_tables_from() -> None:
)


def test_format_show_tables() -> None:
"""
Test format when `ast.sql()` raises an exception.

In that case sqlparse should be used instead.
"""
assert (
SQLScript("SHOW TABLES FROM s1 like '%order%'", "mysql").format()
== "SHOW TABLES FROM s1 LIKE '%order%'"
)


def test_format_no_dialect() -> None:
"""
Test format with an engine that has no corresponding dialect.
"""
assert (
SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "firebolt").format()
== "SELECT col\nFROM t\nWHERE col NOT IN (1,\n 2)"
)


def test_split_no_dialect() -> None:
"""
Test the statement split when the engine has no corresponding dialect.
"""
sql = "SELECT col FROM t WHERE col NOT IN (1, 2); SELECT * FROM t; SELECT foo"
statements = SQLScript(sql, "firebolt").statements
assert len(statements) == 3
assert statements[0]._sql == "SELECT col FROM t WHERE col NOT IN (1, 2)"
assert statements[1]._sql == "SELECT * FROM t"
assert statements[2]._sql == "SELECT foo"


def test_extract_tables_show_columns_from() -> None:
"""
Test `SHOW COLUMNS FROM`.
Expand Down
Loading