Skip to content

Commit 80b334f

Browse files
committed
test: add unit tests for trino.sqlalchemy
1 parent c7fbac3 commit 80b334f

File tree

6 files changed

+199
-1
lines changed

6 files changed

+199
-1
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
all_require = kerberos_require + sqlalchemy_require
3232

33-
tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "mock", "pytz", "flake8"]
33+
tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "mock", "pytz", "flake8", "assertpy"]
3434

3535
py27_require = ["ipaddress", "typing"]
3636

tests/__init__.py

Whitespace-only changes.

tests/sqlalchemy/__init__.py

Whitespace-only changes.

tests/sqlalchemy/conftest.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from assertpy import add_extension, assert_that
2+
from sqlalchemy.sql.sqltypes import ARRAY
3+
4+
from trino.sqlalchemy.datatype import SQLType, MAP, ROW
5+
6+
7+
def assert_sqltype(this: SQLType, that: SQLType):
8+
if isinstance(this, type):
9+
this = this()
10+
if isinstance(that, type):
11+
that = that()
12+
assert_that(type(this)).is_same_as(type(that))
13+
if isinstance(this, ARRAY):
14+
assert_sqltype(this.item_type, that.item_type)
15+
if this.dimensions is None or this.dimensions == 1:
16+
assert_that(that.dimensions).is_in(None, 1)
17+
else:
18+
assert_that(this.dimensions).is_equal_to(this.dimensions)
19+
elif isinstance(this, MAP):
20+
assert_sqltype(this.key_type, that.key_type)
21+
assert_sqltype(this.value_type, that.value_type)
22+
elif isinstance(this, ROW):
23+
assert_that(len(this.attr_types)).is_equal_to(len(that.attr_types))
24+
for name, this_attr in this.attr_types.items():
25+
that_attr = this.attr_types[name]
26+
assert_sqltype(this_attr, that_attr)
27+
else:
28+
assert_that(str(this)).is_equal_to(str(that))
29+
30+
31+
@add_extension
32+
def is_sqltype(self, that):
33+
this = self.val
34+
assert_sqltype(this, that)
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import pytest
2+
from assertpy import assert_that
3+
from sqlalchemy.sql.sqltypes import *
4+
from sqlalchemy.sql.type_api import TypeEngine
5+
6+
from trino.sqlalchemy import datatype
7+
from trino.sqlalchemy.datatype import MAP, ROW
8+
9+
10+
@pytest.mark.parametrize(
11+
'type_str, sql_type',
12+
datatype._type_map.items(),
13+
ids=datatype._type_map.keys()
14+
)
15+
def test_parse_simple_type(type_str: str, sql_type: TypeEngine):
16+
actual_type = datatype.parse_sqltype(type_str)
17+
if not isinstance(actual_type, type):
18+
actual_type = type(actual_type)
19+
assert_that(actual_type).is_equal_to(sql_type)
20+
21+
22+
parse_type_options_testcases = {
23+
'VARCHAR(10)': VARCHAR(10),
24+
'DECIMAL(20)': DECIMAL(20),
25+
'DECIMAL(20, 3)': DECIMAL(20, 3),
26+
}
27+
28+
29+
@pytest.mark.parametrize(
30+
'type_str, sql_type',
31+
parse_type_options_testcases.items(),
32+
ids=parse_type_options_testcases.keys()
33+
)
34+
def test_parse_type_options(type_str: str, sql_type: TypeEngine):
35+
actual_type = datatype.parse_sqltype(type_str)
36+
assert_that(actual_type).is_sqltype(sql_type)
37+
38+
39+
parse_array_testcases = {
40+
'array(integer)': ARRAY(INTEGER()),
41+
'array(varchar(10))': ARRAY(VARCHAR(10)),
42+
'array(decimal(20,3))': ARRAY(DECIMAL(20, 3)),
43+
'array(array(varchar(10)))': ARRAY(VARCHAR(10), dimensions=2),
44+
}
45+
46+
47+
@pytest.mark.parametrize(
48+
'type_str, sql_type',
49+
parse_array_testcases.items(),
50+
ids=parse_array_testcases.keys()
51+
)
52+
def test_parse_array(type_str: str, sql_type: ARRAY):
53+
actual_type = datatype.parse_sqltype(type_str)
54+
assert_that(actual_type).is_sqltype(sql_type)
55+
56+
57+
parse_map_testcases = {
58+
'map(char, integer)': MAP(CHAR(), INTEGER()),
59+
'map(varchar(10), varchar(10))': MAP(VARCHAR(10), VARCHAR(10)),
60+
'map(varchar(10), decimal(20,3))': MAP(VARCHAR(10), DECIMAL(20, 3)),
61+
'map(char, array(varchar(10)))': MAP(CHAR(), ARRAY(VARCHAR(10))),
62+
'map(varchar(10), array(varchar(10)))': MAP(VARCHAR(10), ARRAY(VARCHAR(10))),
63+
'map(varchar(10), array(array(varchar(10))))': MAP(VARCHAR(10), ARRAY(VARCHAR(10), dimensions=2)),
64+
}
65+
66+
67+
@pytest.mark.parametrize(
68+
'type_str, sql_type',
69+
parse_map_testcases.items(),
70+
ids=parse_map_testcases.keys()
71+
)
72+
def test_parse_map(type_str: str, sql_type: ARRAY):
73+
actual_type = datatype.parse_sqltype(type_str)
74+
assert_that(actual_type).is_sqltype(sql_type)
75+
76+
77+
parse_row_testcases = {
78+
'row(a integer, b varchar)': ROW(dict(a=INTEGER(), b=VARCHAR())),
79+
'row(a varchar(20), b decimal(20,3))': ROW(dict(a=VARCHAR(20), b=DECIMAL(20, 3))),
80+
'row(x array(varchar(10)), y array(array(varchar(10))), z decimal(20,3))':
81+
ROW(dict(x=ARRAY(VARCHAR(10)), y=ARRAY(VARCHAR(10), dimensions=2), z=DECIMAL(20, 3))),
82+
}
83+
84+
85+
@pytest.mark.parametrize(
86+
'type_str, sql_type',
87+
parse_row_testcases.items(),
88+
ids=parse_row_testcases.keys()
89+
)
90+
def test_parse_row(type_str: str, sql_type: ARRAY):
91+
actual_type = datatype.parse_sqltype(type_str)
92+
assert_that(actual_type).is_sqltype(sql_type)
93+
94+
95+
parse_datetime_testcases = {
96+
'date': DATE(),
97+
'time': TIME(),
98+
'time with time zone': TIME(timezone=True),
99+
'timestamp': TIMESTAMP(),
100+
'timestamp with time zone': TIMESTAMP(timezone=True),
101+
}
102+
103+
104+
@pytest.mark.parametrize(
105+
'type_str, sql_type',
106+
parse_datetime_testcases.items(),
107+
ids=parse_datetime_testcases.keys()
108+
)
109+
def test_parse_datetime(type_str: str, sql_type: ARRAY):
110+
actual_type = datatype.parse_sqltype(type_str)
111+
assert_that(actual_type).is_sqltype(sql_type)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from typing import *
2+
3+
import pytest
4+
from assertpy import assert_that
5+
6+
from trino.sqlalchemy import datatype
7+
8+
split_string_testcases = {
9+
'10': ['10'],
10+
'10,3': ['10', '3'],
11+
'varchar': ['varchar'],
12+
'varchar,int': ['varchar', 'int'],
13+
'varchar,int,float': ['varchar', 'int', 'float'],
14+
'array(varchar)': ['array(varchar)'],
15+
'array(varchar),int': ['array(varchar)', 'int'],
16+
'array(varchar(20))': ['array(varchar(20))'],
17+
'array(varchar(20)),int': ['array(varchar(20))', 'int'],
18+
'array(varchar(20)),array(varchar(20))': ['array(varchar(20))', 'array(varchar(20))'],
19+
'map(varchar, integer),int': ['map(varchar, integer)', 'int'],
20+
'map(varchar(20), integer),int': ['map(varchar(20), integer)', 'int'],
21+
'map(varchar(20), varchar(20)),int': ['map(varchar(20), varchar(20))', 'int'],
22+
'map(varchar(20), varchar(20)),array(varchar)': ['map(varchar(20), varchar(20))', 'array(varchar)'],
23+
'row(first_name varchar(20), last_name varchar(20)),int':
24+
['row(first_name varchar(20), last_name varchar(20))', 'int'],
25+
}
26+
27+
28+
@pytest.mark.parametrize(
29+
'input_string, output_strings',
30+
split_string_testcases.items(),
31+
ids=split_string_testcases.keys()
32+
)
33+
def test_split_string(input_string: str, output_strings: List[str]):
34+
actual = list(datatype.split(input_string))
35+
assert_that(actual).is_equal_to(output_strings)
36+
37+
38+
split_delimiter_testcases = [
39+
('first,second', ',', ['first', 'second']),
40+
('first second', ' ', ['first', 'second']),
41+
('first|second', '|', ['first', 'second']),
42+
('first,second third', ',', ['first', 'second third']),
43+
('first,second third', ' ', ['first,second', 'third']),
44+
]
45+
46+
47+
@pytest.mark.parametrize(
48+
'input_string, delimiter, output_strings',
49+
split_delimiter_testcases,
50+
)
51+
def test_split_delimiter(input_string: str, delimiter: str, output_strings: List[str]):
52+
actual = list(datatype.split(input_string, delimiter=delimiter))
53+
assert_that(actual).is_equal_to(output_strings)

0 commit comments

Comments
 (0)