Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable running of tests in tests/db_engine_specs #8902

Merged
merged 4 commits into from
Dec 31, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor and add tests for pyodbc.Row conversion
  • Loading branch information
robdiciuccio committed Dec 30, 2019
commit dd4af7ac7eb98cc112fb095338c9a41282244013
12 changes: 12 additions & 0 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,3 +856,15 @@ def column_datatype_to_string(
:return: Compiled column type
"""
return sqla_column_type.compile(dialect=dialect).upper()

@staticmethod
def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple]:
"""
Convert pyodbc.Row objects from `fetch_data` to tuples.

:param data: List of tuples or pyodbc.Row objects
:return: List of tuples
"""
if data and type(data[0]).__name__ == "Row":
data = [tuple(row) for row in data]
return data
4 changes: 1 addition & 3 deletions superset/db_engine_specs/exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,4 @@ class ExasolEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
data = super().fetch_data(cursor, limit)
# Lists of `pyodbc.Row` need to be unpacked further
if data and type(data[0]).__name__ == "Row":
data = [tuple(row) for row in data]
return data
return cls.pyodbc_rows_to_tuples(data)
5 changes: 2 additions & 3 deletions superset/db_engine_specs/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,8 @@ def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:
@classmethod
def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
data = super().fetch_data(cursor, limit)
if data and type(data[0]).__name__ == "Row":
data = [tuple(row) for row in data]
return data
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)

column_types = [
(String(), re.compile(r"^(?<!N)((VAR){0,1}CHAR|TEXT|STRING)", re.IGNORECASE)),
Expand Down
25 changes: 25 additions & 0 deletions tests/db_engine_specs/base_engine_spec_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
from unittest import mock

from tests.test_app import app # isort:skip
Expand All @@ -23,6 +24,8 @@
from superset.utils.core import get_example_database
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase

from ..fixtures.pyodbcRow import Row


class DbEngineSpecsTests(DbEngineSpecTestCase):
def test_extract_limit_from_query(self, engine_spec_class=BaseEngineSpec):
Expand Down Expand Up @@ -206,3 +209,25 @@ def test_column_datatype_to_string(self):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertIsNone(BaseEngineSpec.convert_dttm("", dttm))

def test_pyodbc_rows_to_tuples(self):
# Test for case when pyodbc.Row is returned (odbc driver)
data = [
Row((1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000))),
Row((2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000))),
]
expected = [
(1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
(2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
]
result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
self.assertListEqual(result, expected)

def test_pyodbc_rows_to_tuples_passthrough(self):
# Test for case when tuples are returned
data = [
(1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
(2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
]
result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
self.assertListEqual(result, data)
15 changes: 15 additions & 0 deletions tests/db_engine_specs/mssql_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest.mock as mock

from sqlalchemy import column, table
from sqlalchemy.dialects import mssql
from sqlalchemy.sql import select
from sqlalchemy.types import String, UnicodeText

from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
from tests.db_engine_specs.base_tests import DbEngineSpecTestCase

Expand Down Expand Up @@ -87,3 +90,15 @@ def test_convert_dttm(self):
MssqlEngineSpec.convert_dttm("SMALLDATETIME", dttm),
"CONVERT(SMALLDATETIME, '2019-01-02 03:04:05', 20)",
)

@mock.patch.object(
MssqlEngineSpec, "pyodbc_rows_to_tuples", return_value="converted"
)
def test_fetch_data(self, mock_pyodbc_rows_to_tuples):
data = [(1, "foo")]
with mock.patch.object(
BaseEngineSpec, "fetch_data", return_value=data
) as mock_fetch:
result = MssqlEngineSpec.fetch_data(None, 0)
mock_pyodbc_rows_to_tuples.assert_called_once_with(data)
self.assertEqual(result, "converted")