From 8c376548e3cf6464e5710e80120f82227174fcdc Mon Sep 17 00:00:00 2001
From: Dave McNulla
Date: Mon, 31 Jan 2022 19:15:42 -0800
Subject: [PATCH] fix(teradata): LIMIT syntax (#18240)

Co-authored-by: Mccush, Jack
Co-authored-by: Jack McCush <33156805+mccushjack@users.noreply.github.com>
Co-authored-by: Beto Dealmeida
Co-authored-by: David McNulla
---
 docs/installation.rst                         |  17 +-
 .../docs/Connecting to Databases/index.mdx    |   2 +-
 .../docs/Connecting to Databases/teradata.mdx |  15 +-
 setup.py                                      |   2 +-
 superset/db_engine_specs/teradata.py          | 240 +++++++++++++++++-
 .../db_engine_specs/test_teradata.py          |  87 +++++++
 6 files changed, 338 insertions(+), 25 deletions(-)
 create mode 100644 tests/unit_tests/db_engine_specs/test_teradata.py

diff --git a/docs/installation.rst b/docs/installation.rst
index bca5cb4678996..3ae4a344b9724 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -569,7 +569,7 @@ Here's a list of some of the recommended packages.
 +------------------+-------------------------------------------------------------------+-------------------------------------------------+
 | SQL Server       | ``"apache-superset[mssql]"``                                      | ``mssql://``                                    |
 +------------------+-------------------------------------------------------------------+-------------------------------------------------+
-| Teradata         | ``"apache-superset[teradata]"``                                   | ``teradata://``                                 |
+| Teradata         | ``"apache-superset[teradata]"``                                   | ``teradatasql://``                              |
 +------------------+-------------------------------------------------------------------+-------------------------------------------------+
 | Vertica          | ``"apache-superset[vertical]"``                                   | ``vertica+vertica_python://``                   |
 +------------------+-------------------------------------------------------------------+-------------------------------------------------+
@@ -753,16 +753,17 @@ Teradata

 The connection string for Teradata looks like this ::

-    teradata://{user}:{password}@{host}
+The recommended connector library is
+[teradatasql](https://github.com/Teradata/python-driver).
+Also, see the latest on [PyPI](https://pypi.org/project/teradatasql/).

-*Note*: Its required to have Teradata ODBC drivers installed and environment variables configured for proper work of sqlalchemy dialect. Teradata ODBC Drivers available here: https://downloads.teradata.com/download/connectivity/odbc-driver/linux
+The connection string for Teradata looks like this:

-Required environment variables: ::
-
-    export ODBCINI=/.../teradata/client/ODBC_64/odbc.ini
-    export ODBCINST=/.../teradata/client/ODBC_64/odbcinst.ini
+```
+teradatasql://{user}:{password}@{host}
+```

-See `Teradata SQLAlchemy <https://github.com/Teradata/sqlalchemy-teradata>`_.
+See `Teradata SQL Native Python Driver <https://github.com/Teradata/python-driver>`_.

 Apache Drill
 ------------
diff --git a/docs/src/pages/docs/Connecting to Databases/index.mdx b/docs/src/pages/docs/Connecting to Databases/index.mdx
index 8ef2f163bf82d..7f7867da773d0 100644
--- a/docs/src/pages/docs/Connecting to Databases/index.mdx
+++ b/docs/src/pages/docs/Connecting to Databases/index.mdx
@@ -56,7 +56,7 @@ A list of some of the recommended packages.
 |[Snowflake](/docs/databases/snowflake)|```pip install snowflake-sqlalchemy```|```snowflake://{user}:{password}@{account}.{region}/{database}?role={role}&warehouse={warehouse}```|
 |SQLite||```sqlite://```|
 |[SQL Server](/docs/databases/sql-server)|```pip install pymssql```|```mssql://```|
-|[Teradata](/docs/databases/teradata)|```pip install sqlalchemy-teradata```|```teradata://{user}:{password}@{host}```|
+|[Teradata](/docs/databases/teradata)|```pip install teradatasql```|```teradatasql://{user}:{password}@{host}```|
 |[Vertica](/docs/databases/vertica)|```pip install sqlalchemy-vertica-python```|```vertica+vertica_python://:@/```|

 ***
diff --git a/docs/src/pages/docs/Connecting to Databases/teradata.mdx b/docs/src/pages/docs/Connecting to Databases/teradata.mdx
index e594d4b0b8d12..ac5f285852a7f 100644
--- a/docs/src/pages/docs/Connecting to Databases/teradata.mdx
+++ b/docs/src/pages/docs/Connecting to Databases/teradata.mdx
@@ -9,21 +9,10 @@ version: 1
 ## Teradata

 The recommended connector library is
-[sqlalchemy-teradata](https://github.com/Teradata/sqlalchemy-teradata).
+[teradatasql](https://github.com/Teradata/python-driver).

 The connection string for Teradata looks like this:

 ```
-teradata://{user}:{password}@{host}
-```
-
-Note: Its required to have Teradata ODBC drivers installed and environment variables configured for
-proper work of sqlalchemy dialect. Teradata ODBC Drivers available here:
-https://downloads.teradata.com/download/connectivity/odbc-driver/linux
-
-Required environment variables:
-
-```
-export ODBCINI=/.../teradata/client/ODBC_64/odbc.ini
-export ODBCINST=/.../teradata/client/ODBC_64/odbcinst.ini
+teradatasql://{user}:{password}@{host}
 ```
diff --git a/setup.py b/setup.py
index d2e2590958c14..ba0f5bad90b4d 100644
--- a/setup.py
+++ b/setup.py
@@ -159,7 +159,7 @@ def get_git_sha() -> str:
         "snowflake": [
             "snowflake-sqlalchemy==1.2.4"
         ],  # PINNED! 1.2.5 introduced breaking changes requiring sqlalchemy>=1.4.0
-        "teradata": ["sqlalchemy-teradata==0.9.0.dev0"],
+        "teradata": ["teradatasql>=16.20.0.23"],
         "thumbnails": ["Pillow>=8.3.2, <10.0.0"],
         "vertica": ["sqlalchemy-vertica-python>=0.5.9, < 0.6"],
         "netezza": ["nzalchemy>=11.0.2"],
diff --git a/superset/db_engine_specs/teradata.py b/superset/db_engine_specs/teradata.py
index 8fd1641064d6c..8e7589980b15b 100644
--- a/superset/db_engine_specs/teradata.py
+++ b/superset/db_engine_specs/teradata.py
@@ -14,13 +14,223 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+from typing import Optional, Set
+
+import sqlparse
+from sqlparse.sql import (
+    Identifier,
+    IdentifierList,
+    Parenthesis,
+    remove_quotes,
+    Token,
+    TokenList,
+)
+from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
+from sqlparse.utils import imt
+
 from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
+from superset.sql_parse import Table
+
+PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"}
+CTE_PREFIX = "CTE__"
+JOIN = " JOIN"
+
+
+def _extract_limit_from_query_td(statement: TokenList) -> Optional[int]:
+    td_limit_keywords = {"TOP", "SAMPLE"}
+    str_statement = str(statement)
+    str_statement = str_statement.replace("\n", " ").replace("\r", "")
+    token = str_statement.rstrip().split(" ")
+    token = [part for part in token if part]
+    limit = None
+
+    for i, _ in enumerate(token):
+        if token[i].upper() in td_limit_keywords and len(token) - 1 > i:
+            try:
+                limit = int(token[i + 1])
+            except ValueError:
+                limit = None
+            break
+    return limit
+
+
+class ParsedQueryTeradata:
+    def __init__(
+        self, sql_statement: str, strip_comments: bool = False, uri_type: str = "None"
+    ):
+
+        if strip_comments:
+            sql_statement = sqlparse.format(sql_statement, strip_comments=True)
+
+        self.sql: str = sql_statement
+        self._tables: Set[Table] = set()
+        self._alias_names: Set[str] = set()
+        self._limit: Optional[int] = None
+        self.uri_type: str = uri_type
+
+        self._parsed = sqlparse.parse(self.stripped())
+        for statement in self._parsed:
+            self._limit = _extract_limit_from_query_td(statement)
+
+    @property
+    def tables(self) -> Set[Table]:
+        if not self._tables:
+            for statement in self._parsed:
+                self._extract_from_token(statement)
+
+            self._tables = {
+                table for table in self._tables if str(table) not in self._alias_names
+            }
+        return self._tables
+
+    def stripped(self) -> str:
+        return self.sql.strip(" \t\n;")
+
+    def _extract_from_token(self, token: Token) -> None:
+        """
+        <Identifier> store a list of subtokens and <IdentifierList> store lists of
+        subtoken list.
+
+        It extracts <Identifier> and <IdentifierList> from :param token: and loops
+        through all subtokens recursively. It finds table_name_preceding_token and
+        passes <Identifier> and <IdentifierList> to self._process_tokenlist to populate
+        self._tables.
+
+        :param token: instance of Token or child class, e.g. TokenList, to be processed
+        """
+        if not hasattr(token, "tokens"):
+            return
+
+        table_name_preceding_token = False
+
+        for item in token.tokens:
+            if item.is_group and (
+                not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
+            ):
+                self._extract_from_token(item)
+
+            if item.ttype in Keyword and (
+                item.normalized in PRECEDES_TABLE_NAME or item.normalized.endswith(JOIN)
+            ):
+                table_name_preceding_token = True
+                continue
+
+            if item.ttype in Keyword:
+                table_name_preceding_token = False
+                continue
+            if table_name_preceding_token:
+                if isinstance(item, Identifier):
+                    self._process_tokenlist(item)
+                elif isinstance(item, IdentifierList):
+                    for item_list in item.get_identifiers():
+                        if isinstance(item_list, TokenList):
+                            self._process_tokenlist(item_list)
+            elif isinstance(item, IdentifierList):
+                if any(not self._is_identifier(sub_item) for sub_item in item.tokens):
+                    self._extract_from_token(item)
+
+    @staticmethod
+    def _get_table(tlist: TokenList) -> Optional[Table]:
+        """
+        Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
+        construct.
+
+        :param tlist: The SQL tokens
+        :returns: The table if the name conforms
+        """
+
+        # Strip the alias if present.
+        idx = len(tlist.tokens)
+
+        if tlist.has_alias():
+            ws_idx, _ = tlist.token_next_by(t=Whitespace)
+
+            if ws_idx != -1:
+                idx = ws_idx
+
+        tokens = tlist.tokens[:idx]
+
+        odd_token_number = len(tokens) in (1, 3, 5)
+        qualified_name_parts = all(
+            imt(token, t=[Name, String]) for token in tokens[::2]
+        )
+        dot_separators = all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2])
+        if odd_token_number and qualified_name_parts and dot_separators:
+            return Table(*[remove_quotes(token.value) for token in tokens[::-2]])
+
+        return None
+
+    @staticmethod
+    def _is_identifier(token: Token) -> bool:
+        return isinstance(token, (IdentifierList, Identifier))
+
+    def _process_tokenlist(self, token_list: TokenList) -> None:
+        """
+        Add table names to table set
+
+        :param token_list: TokenList to be processed
+        """
+        # exclude subselects
+        if "(" not in str(token_list):
+            table = self._get_table(token_list)
+            if table and not table.table.startswith(CTE_PREFIX):
+                self._tables.add(table)
+            return
+
+        # store aliases
+        if token_list.has_alias():
+            self._alias_names.add(token_list.get_alias())
+
+        # some aliases are not parsed properly
+        if token_list.tokens[0].ttype == Name:
+            self._alias_names.add(token_list.tokens[0].value)
+        self._extract_from_token(token_list)
+
+    def set_or_update_query_limit_td(self, new_limit: int) -> str:
+        td_sel_keywords = {"SELECT", "SEL"}
+        td_limit_keywords = {"TOP", "SAMPLE"}
+        statement = self._parsed[0]
+
+        if not self._limit:
+            final_limit = new_limit
+        elif new_limit < self._limit:
+            final_limit = new_limit
+        else:
+            final_limit = self._limit
+
+        str_statement = str(statement)
+        str_statement = str_statement.replace("\n", " ").replace("\r", "")
+
+        tokens = str_statement.rstrip().split(" ")
+        tokens = [token for token in tokens if token]
+
+        if limit_not_in_sql(str_statement, td_limit_keywords):
+            selects = [i for i, word in enumerate(tokens) if word in td_sel_keywords]
+            first_select = selects[0]
+            tokens.insert(first_select + 1, "TOP")
+            tokens.insert(first_select + 2, str(final_limit))
+
+        next_is_limit_token = False
+        new_tokens = []
+
+        for token in tokens:
+            if token.upper() in td_limit_keywords:
+                next_is_limit_token = True
+            elif next_is_limit_token:
+                if token.isdigit():
+                    token = str(final_limit)
+                next_is_limit_token = False
+            new_tokens.append(token)
+
+        return " ".join(new_tokens)


 class TeradataEngineSpec(BaseEngineSpec):
     """Dialect for Teradata DB."""

-    engine = "teradata"
+    engine = "teradatasql"
     engine_name = "Teradata"
     limit_method = LimitMethod.WRAP_SQL
     max_column_name_length = 30  # since 14.10 this is 128
@@ -32,7 +242,7 @@ class TeradataEngineSpec(BaseEngineSpec):
         "P1D": "TRUNC(CAST({col} as DATE), 'DDD')",
         "P1W": "TRUNC(CAST({col} as DATE), 'WW')",
         "P1M": "TRUNC(CAST({col} as DATE), 'MONTH')",
-        "P3M": "TRUNC(CAST({col} as DATE), 'Q')",
+        "P0.25Y": "TRUNC(CAST({col} as DATE), 'Q')",
         "P1Y": "TRUNC(CAST({col} as DATE), 'YEAR')",
     }

@@ -43,3 +253,29 @@ def epoch_to_dttm(cls) -> str:
             "AT 0)) AT 0) + (({col} MOD 86400) * INTERVAL '00:00:01' "
             "HOUR TO SECOND) AS TIMESTAMP(0))"
         )
+
+    @classmethod
+    def apply_limit_to_sql(
+        cls, sql: str, limit: int, database: str = "Database", force: bool = False
+    ) -> str:
+        """
+        Alter the SQL statement to apply a TOP clause.
+        This overrides the similar method in base.py because Teradata does not
+        support the LIMIT syntax.
+        :param sql: SQL query
+        :param limit: Maximum number of rows to be returned by the query
+        :param database: Database instance
+        :return: SQL query with limit clause
+        """
+
+        parsed_query = ParsedQueryTeradata(sql)
+        sql = parsed_query.set_or_update_query_limit_td(limit)
+
+        return sql
+
+
+def limit_not_in_sql(sql: str, limit_words: Set[str]) -> bool:
+    for limit_word in limit_words:
+        if limit_word in sql:
+            return False
+    return True
diff --git a/tests/unit_tests/db_engine_specs/test_teradata.py b/tests/unit_tests/db_engine_specs/test_teradata.py
new file mode 100644
index 0000000000000..8d9fc08c4a631
--- /dev/null
+++ b/tests/unit_tests/db_engine_specs/test_teradata.py
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument, import-outside-toplevel, protected-access
+
+from flask.ctx import AppContext
+
+
+def test_ParsedQueryTeradata_lower_limit(app_context: AppContext) -> None:
+    """
+    Test the custom ``ParsedQueryTeradata`` that calls ``_extract_limit_from_query_td``
+
+    The class looks for the Teradata limit keywords TOP and SAMPLE instead of the
+    LIMIT keyword used by other dialects.
+    """
+    from superset.db_engine_specs.teradata import TeradataEngineSpec
+
+    sql = "SEL TOP 1000 * FROM My_table;"
+    limit = 100
+
+    assert str(TeradataEngineSpec.apply_limit_to_sql(sql, limit, "Database")) == (
+        "SEL TOP 100 * FROM My_table"
+    )
+
+
+def test_ParsedQueryTeradata_higher_limit(app_context: AppContext) -> None:
+    """
+    Test the custom ``ParsedQueryTeradata`` that calls ``_extract_limit_from_query_td``
+
+    The class looks for the Teradata limit keywords TOP and SAMPLE instead of the
+    LIMIT keyword used by other dialects.
+    """
+    from superset.db_engine_specs.teradata import TeradataEngineSpec
+
+    sql = "SEL TOP 1000 * FROM My_table;"
+    limit = 10000
+
+    assert str(TeradataEngineSpec.apply_limit_to_sql(sql, limit, "Database")) == (
+        "SEL TOP 1000 * FROM My_table"
+    )
+
+
+def test_ParsedQueryTeradata_equal_limit(app_context: AppContext) -> None:
+    """
+    Test the custom ``ParsedQueryTeradata`` that calls ``_extract_limit_from_query_td``
+
+    The class looks for the Teradata limit keywords TOP and SAMPLE instead of the
+    LIMIT keyword used by other dialects.
+    """
+    from superset.db_engine_specs.teradata import TeradataEngineSpec
+
+    sql = "SEL TOP 1000 * FROM My_table;"
+    limit = 1000
+
+    assert str(TeradataEngineSpec.apply_limit_to_sql(sql, limit, "Database")) == (
+        "SEL TOP 1000 * FROM My_table"
+    )
+
+
+def test_ParsedQueryTeradata_no_limit(app_context: AppContext) -> None:
+    """
+    Test the custom ``ParsedQueryTeradata`` that calls ``_extract_limit_from_query_td``
+
+    The class looks for the Teradata limit keywords TOP and SAMPLE instead of the
+    LIMIT keyword used by other dialects.
+    """
+    from superset.db_engine_specs.teradata import TeradataEngineSpec
+
+    sql = "SEL * FROM My_table;"
+    limit = 1000
+
+    assert str(TeradataEngineSpec.apply_limit_to_sql(sql, limit, "Database")) == (
+        "SEL TOP 1000 * FROM My_table"
+    )
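
For reviewers who want to try the change locally, here is a minimal sketch (not part of the patch) of how the new limit handling behaves. It assumes a Superset checkout with this commit applied so that `superset.db_engine_specs.teradata` is importable; the unit tests above wrap the same calls in a Flask app context, and the inputs and expected outputs below mirror those tests.

```python
# Minimal sketch, not part of the patch: exercise the new Teradata limit handling.
# Assumes superset.db_engine_specs.teradata from this commit is importable; the
# inputs and expected outputs mirror the unit tests added in this change.
from superset.db_engine_specs.teradata import TeradataEngineSpec

# An existing TOP is lowered when the requested limit is smaller...
print(TeradataEngineSpec.apply_limit_to_sql("SEL TOP 1000 * FROM My_table;", 100))
# SEL TOP 100 * FROM My_table

# ...but never raised above what the query already asks for.
print(TeradataEngineSpec.apply_limit_to_sql("SEL TOP 1000 * FROM My_table;", 10000))
# SEL TOP 1000 * FROM My_table

# A query with no TOP/SAMPLE gets a TOP injected right after SELECT/SEL.
print(TeradataEngineSpec.apply_limit_to_sql("SEL * FROM My_table;", 1000))
# SEL TOP 1000 * FROM My_table
```

The documentation changes above switch the SQLAlchemy URI scheme to `teradatasql://{user}:{password}@{host}`, matching `engine = "teradatasql"` on the spec, so the ODBC driver and environment variables from the old instructions are no longer required.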