Skip to content

Commit

Permalink
feat: support sqlalchemy
Browse files Browse the repository at this point in the history
Signed-off-by: Đặng Minh Dũng <dungdm93@live.com>
  • Loading branch information
dungdm93 committed Nov 25, 2021
1 parent c8144de commit c5d19b5
Show file tree
Hide file tree
Showing 18 changed files with 1,196 additions and 55 deletions.
11 changes: 11 additions & 0 deletions integration_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
28 changes: 16 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,23 @@

import ast
import re
from setuptools import setup
import textwrap

from setuptools import setup

_version_re = re.compile(r"__version__\s+=\s+(.*)")


with open("trino/__init__.py", "rb") as f:
trino_version = _version_re.search(f.read().decode("utf-8"))
assert trino_version is not None
version = str(ast.literal_eval(trino_version.group(1)))


kerberos_require = ["requests_kerberos"]
sqlalchemy_require = ["sqlalchemy~=1.3"]

all_require = kerberos_require + []
all_require = kerberos_require + sqlalchemy_require

tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "pytz", "click"]
tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "pytz", "click", "assertpy"]

setup(
name="trino",
Expand All @@ -44,19 +43,17 @@
description="Client for the Trino distributed SQL Engine",
long_description=textwrap.dedent(
"""
Client for Trino (https://trino.io), a distributed SQL engine for
interactive and batch big data processing. Provides a low-level client and
a DBAPI 2.0 implementation.
"""
Client for Trino (https://trino.io), a distributed SQL engine for
interactive and batch big data processing. Provides a low-level client and
a DBAPI 2.0 implementation.
"""
),
license="Apache 2.0",
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Operating System :: MacOS :: MacOS X",
"Operating System :: POSIX",
"Operating System :: Microsoft :: Windows",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
Expand All @@ -66,13 +63,20 @@
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
"Topic :: Database",
"Topic :: Database :: Front-Ends",
],
python_requires='>=3.6',
install_requires=["requests"],
extras_require={
"all": all_require,
"kerberos": kerberos_require,
"sqlalchemy": sqlalchemy_require,
"tests": tests_require,
},
entry_points={
"sqlalchemy.dialects": [
"trino = trino.sqlalchemy.dialect:TrinoDialect",
]
},
)
Empty file added tests/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest
import pytz

import trino
import trino.dbapi
from trino.exceptions import TrinoQueryError
from trino.transaction import IsolationLevel

Expand Down
11 changes: 11 additions & 0 deletions tests/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
46 changes: 46 additions & 0 deletions tests/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from assertpy import add_extension, assert_that
from sqlalchemy.sql.sqltypes import ARRAY

from trino.sqlalchemy.datatype import MAP, ROW, SQLType


def assert_sqltype(this: SQLType, that: SQLType):
if isinstance(this, type):
this = this()
if isinstance(that, type):
that = that()
assert_that(type(this)).is_same_as(type(that))
if isinstance(this, ARRAY):
assert_sqltype(this.item_type, that.item_type)
if this.dimensions is None or this.dimensions == 1:
# ARRAY(dimensions=None) == ARRAY(dimensions=1)
assert_that(that.dimensions).is_in(None, 1)
else:
assert_that(this.dimensions).is_equal_to(this.dimensions)
elif isinstance(this, MAP):
assert_sqltype(this.key_type, that.key_type)
assert_sqltype(this.value_type, that.value_type)
elif isinstance(this, ROW):
assert_that(len(this.attr_types)).is_equal_to(len(that.attr_types))
for (this_attr, that_attr) in zip(this.attr_types, that.attr_types):
assert_that(this_attr[0]).is_equal_to(that_attr[0])
assert_sqltype(this_attr[1], that_attr[1])
else:
assert_that(str(this)).is_equal_to(str(that))


@add_extension
def is_sqltype(self, that):
this = self.val
assert_sqltype(this, that)
183 changes: 183 additions & 0 deletions tests/sqlalchemy/test_datatype_parse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from assertpy import assert_that
from sqlalchemy.sql.sqltypes import (
CHAR, VARCHAR,
ARRAY,
INTEGER, DECIMAL,
DATE, TIME, TIMESTAMP
)
from sqlalchemy.sql.type_api import TypeEngine

from trino.sqlalchemy import datatype
from trino.sqlalchemy.datatype import MAP, ROW


@pytest.mark.parametrize(
'type_str, sql_type',
datatype._type_map.items(),
ids=datatype._type_map.keys()
)
def test_parse_simple_type(type_str: str, sql_type: TypeEngine):
actual_type = datatype.parse_sqltype(type_str)
if not isinstance(actual_type, type):
actual_type = type(actual_type)
assert_that(actual_type).is_equal_to(sql_type)


parse_cases_testcases = {
'char(10)': CHAR(10),
'Char(10)': CHAR(10),
'char': CHAR(),
'cHaR': CHAR(),
'VARCHAR(10)': VARCHAR(10),
'varCHAR(10)': VARCHAR(10),
'VARchar(10)': VARCHAR(10),
'VARCHAR': VARCHAR(),
'VaRchAr': VARCHAR(),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_cases_testcases.items(),
ids=parse_cases_testcases.keys()
)
def test_parse_cases(type_str: str, sql_type: TypeEngine):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_type_options_testcases = {
'CHAR(10)': CHAR(10),
'VARCHAR(10)': VARCHAR(10),
'DECIMAL(20)': DECIMAL(20),
'DECIMAL(20, 3)': DECIMAL(20, 3),
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_type_options_testcases.items(),
ids=parse_type_options_testcases.keys()
)
def test_parse_type_options(type_str: str, sql_type: TypeEngine):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_array_testcases = {
'array(integer)': ARRAY(INTEGER()),
'array(varchar(10))': ARRAY(VARCHAR(10)),
'array(decimal(20,3))': ARRAY(DECIMAL(20, 3)),
'array(array(varchar(10)))': ARRAY(VARCHAR(10), dimensions=2),
'array(map(char, integer))': ARRAY(MAP(CHAR(), INTEGER())),
'array(row(a integer, b varchar))': ARRAY(ROW([("a", INTEGER()), ("b", VARCHAR())])),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_array_testcases.items(),
ids=parse_array_testcases.keys()
)
def test_parse_array(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_map_testcases = {
'map(char, integer)': MAP(CHAR(), INTEGER()),
'map(varchar(10), varchar(10))': MAP(VARCHAR(10), VARCHAR(10)),
'map(varchar(10), decimal(20,3))': MAP(VARCHAR(10), DECIMAL(20, 3)),
'map(char, array(varchar(10)))': MAP(CHAR(), ARRAY(VARCHAR(10))),
'map(varchar(10), array(varchar(10)))': MAP(VARCHAR(10), ARRAY(VARCHAR(10))),
'map(varchar(10), array(array(varchar(10))))': MAP(VARCHAR(10), ARRAY(VARCHAR(10), dimensions=2)),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_map_testcases.items(),
ids=parse_map_testcases.keys()
)
def test_parse_map(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_row_testcases = {
'row(a integer, b varchar)':
ROW(attr_types=[
("a", INTEGER()),
("b", VARCHAR()),
]),
'row(a varchar(20), b decimal(20,3))':
ROW(attr_types=[
("a", VARCHAR(20)),
("b", DECIMAL(20, 3)),
]),
'row(x array(varchar(10)), y array(array(varchar(10))), z decimal(20,3))':
ROW(attr_types=[
("x", ARRAY(VARCHAR(10))),
("y", ARRAY(VARCHAR(10), dimensions=2)),
("z", DECIMAL(20, 3)),
]),
'row(min timestamp(6) with time zone, max timestamp(6) with time zone)':
ROW(attr_types=[
("min", TIMESTAMP(timezone=True)),
("max", TIMESTAMP(timezone=True)),
]),
'row("first name" varchar, "last name" varchar)':
ROW(attr_types=[
("first name", VARCHAR()),
("last name", VARCHAR()),
]),
'row("foo,bar" varchar, "foo(bar)" varchar, "foo\\"bar" varchar)':
ROW(attr_types=[
(r'foo,bar', VARCHAR()),
(r'foo(bar)', VARCHAR()),
(r'foo"bar', VARCHAR()),
]),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_row_testcases.items(),
ids=parse_row_testcases.keys()
)
def test_parse_row(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_datetime_testcases = {
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
'date': DATE(),
'time': TIME(),
'time with time zone': TIME(timezone=True),
'timestamp': TIMESTAMP(),
'timestamp with time zone': TIMESTAMP(timezone=True),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_datetime_testcases.items(),
ids=parse_datetime_testcases.keys()
)
def test_parse_datetime(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)
Loading

0 comments on commit c5d19b5

Please sign in to comment.