Skip to content

Commit

Permalink
test: add unit tests for trino.sqlalchemy
Browse files Browse the repository at this point in the history
  • Loading branch information
dungdm93 committed Apr 12, 2021
1 parent ee389bc commit 75df687
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
pip install .[tests]
- name: Run tests
run: |
pytest -s tests/ integration_tests/
pytest .
- name: Run linter
run: |
flake8
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

all_require = kerberos_require + sqlalchemy_require

tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "pytz", "flake8"]
tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "pytz", "flake8", "assertpy"]

setup(
name="trino",
Expand Down
Empty file added tests/__init__.py
Empty file.
Empty file added tests/sqlalchemy/__init__.py
Empty file.
34 changes: 34 additions & 0 deletions tests/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from assertpy import add_extension, assert_that
from sqlalchemy.sql.sqltypes import ARRAY

from trino.sqlalchemy.datatype import SQLType, MAP, ROW


def assert_sqltype(this: SQLType, that: SQLType):
if isinstance(this, type):
this = this()
if isinstance(that, type):
that = that()
assert_that(type(this)).is_same_as(type(that))
if isinstance(this, ARRAY):
assert_sqltype(this.item_type, that.item_type)
if this.dimensions is None or this.dimensions == 1:
assert_that(that.dimensions).is_in(None, 1)
else:
assert_that(this.dimensions).is_equal_to(this.dimensions)
elif isinstance(this, MAP):
assert_sqltype(this.key_type, that.key_type)
assert_sqltype(this.value_type, that.value_type)
elif isinstance(this, ROW):
assert_that(len(this.attr_types)).is_equal_to(len(that.attr_types))
for name, this_attr in this.attr_types.items():
that_attr = this.attr_types[name]
assert_sqltype(this_attr, that_attr)
else:
assert_that(str(this)).is_equal_to(str(that))


@add_extension
def is_sqltype(self, that):
this = self.val
assert_sqltype(this, that)
111 changes: 111 additions & 0 deletions tests/sqlalchemy/test_datatype_parse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import pytest
from assertpy import assert_that
from sqlalchemy.sql.sqltypes import *
from sqlalchemy.sql.type_api import TypeEngine

from trino.sqlalchemy import datatype
from trino.sqlalchemy.datatype import MAP, ROW


@pytest.mark.parametrize(
'type_str, sql_type',
datatype._type_map.items(),
ids=datatype._type_map.keys()
)
def test_parse_simple_type(type_str: str, sql_type: TypeEngine):
actual_type = datatype.parse_sqltype(type_str)
if not isinstance(actual_type, type):
actual_type = type(actual_type)
assert_that(actual_type).is_equal_to(sql_type)


parse_type_options_testcases = {
'VARCHAR(10)': VARCHAR(10),
'DECIMAL(20)': DECIMAL(20),
'DECIMAL(20, 3)': DECIMAL(20, 3),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_type_options_testcases.items(),
ids=parse_type_options_testcases.keys()
)
def test_parse_type_options(type_str: str, sql_type: TypeEngine):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_array_testcases = {
'array(integer)': ARRAY(INTEGER()),
'array(varchar(10))': ARRAY(VARCHAR(10)),
'array(decimal(20,3))': ARRAY(DECIMAL(20, 3)),
'array(array(varchar(10)))': ARRAY(VARCHAR(10), dimensions=2),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_array_testcases.items(),
ids=parse_array_testcases.keys()
)
def test_parse_array(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_map_testcases = {
'map(char, integer)': MAP(CHAR(), INTEGER()),
'map(varchar(10), varchar(10))': MAP(VARCHAR(10), VARCHAR(10)),
'map(varchar(10), decimal(20,3))': MAP(VARCHAR(10), DECIMAL(20, 3)),
'map(char, array(varchar(10)))': MAP(CHAR(), ARRAY(VARCHAR(10))),
'map(varchar(10), array(varchar(10)))': MAP(VARCHAR(10), ARRAY(VARCHAR(10))),
'map(varchar(10), array(array(varchar(10))))': MAP(VARCHAR(10), ARRAY(VARCHAR(10), dimensions=2)),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_map_testcases.items(),
ids=parse_map_testcases.keys()
)
def test_parse_map(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_row_testcases = {
'row(a integer, b varchar)': ROW(dict(a=INTEGER(), b=VARCHAR())),
'row(a varchar(20), b decimal(20,3))': ROW(dict(a=VARCHAR(20), b=DECIMAL(20, 3))),
'row(x array(varchar(10)), y array(array(varchar(10))), z decimal(20,3))':
ROW(dict(x=ARRAY(VARCHAR(10)), y=ARRAY(VARCHAR(10), dimensions=2), z=DECIMAL(20, 3))),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_row_testcases.items(),
ids=parse_row_testcases.keys()
)
def test_parse_row(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_datetime_testcases = {
'date': DATE(),
'time': TIME(),
'time with time zone': TIME(timezone=True),
'timestamp': TIMESTAMP(),
'timestamp with time zone': TIMESTAMP(timezone=True),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_datetime_testcases.items(),
ids=parse_datetime_testcases.keys()
)
def test_parse_datetime(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)
53 changes: 53 additions & 0 deletions tests/sqlalchemy/test_datatype_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import *

import pytest
from assertpy import assert_that

from trino.sqlalchemy import datatype

split_string_testcases = {
'10': ['10'],
'10,3': ['10', '3'],
'varchar': ['varchar'],
'varchar,int': ['varchar', 'int'],
'varchar,int,float': ['varchar', 'int', 'float'],
'array(varchar)': ['array(varchar)'],
'array(varchar),int': ['array(varchar)', 'int'],
'array(varchar(20))': ['array(varchar(20))'],
'array(varchar(20)),int': ['array(varchar(20))', 'int'],
'array(varchar(20)),array(varchar(20))': ['array(varchar(20))', 'array(varchar(20))'],
'map(varchar, integer),int': ['map(varchar, integer)', 'int'],
'map(varchar(20), integer),int': ['map(varchar(20), integer)', 'int'],
'map(varchar(20), varchar(20)),int': ['map(varchar(20), varchar(20))', 'int'],
'map(varchar(20), varchar(20)),array(varchar)': ['map(varchar(20), varchar(20))', 'array(varchar)'],
'row(first_name varchar(20), last_name varchar(20)),int':
['row(first_name varchar(20), last_name varchar(20))', 'int'],
}


@pytest.mark.parametrize(
'input_string, output_strings',
split_string_testcases.items(),
ids=split_string_testcases.keys()
)
def test_split_string(input_string: str, output_strings: List[str]):
actual = list(datatype.split(input_string))
assert_that(actual).is_equal_to(output_strings)


split_delimiter_testcases = [
('first,second', ',', ['first', 'second']),
('first second', ' ', ['first', 'second']),
('first|second', '|', ['first', 'second']),
('first,second third', ',', ['first', 'second third']),
('first,second third', ' ', ['first,second', 'third']),
]


@pytest.mark.parametrize(
'input_string, delimiter, output_strings',
split_delimiter_testcases,
)
def test_split_delimiter(input_string: str, delimiter: str, output_strings: List[str]):
actual = list(datatype.split(input_string, delimiter=delimiter))
assert_that(actual).is_equal_to(output_strings)

0 comments on commit 75df687

Please sign in to comment.