Skip to content

Commit b039b35

Browse files
committed
Support for TIME(p) and TIMESTAMP(p) to SQLAlchemy
1 parent 771eec3 commit b039b35

File tree

4 files changed

+74
-18
lines changed

4 files changed

+74
-18
lines changed

tests/unit/sqlalchemy/conftest.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pytest
1313
from sqlalchemy.sql.sqltypes import ARRAY
1414

15-
from trino.sqlalchemy.datatype import MAP, ROW, SQLType
15+
from trino.sqlalchemy.datatype import MAP, ROW, SQLType, TIMESTAMP, TIME
1616

1717

1818
@pytest.fixture(scope="session")
@@ -40,6 +40,15 @@ def _assert_sqltype(this: SQLType, that: SQLType):
4040
for (this_attr, that_attr) in zip(this.attr_types, that.attr_types):
4141
assert this_attr[0] == that_attr[0]
4242
_assert_sqltype(this_attr[1], that_attr[1])
43+
44+
elif isinstance(this, TIME):
45+
assert this.precision == that.precision
46+
assert this.timezone == that.timezone
47+
48+
elif isinstance(this, TIMESTAMP):
49+
assert this.precision == that.precision
50+
assert this.timezone == that.timezone
51+
4352
else:
4453
assert str(this) == str(that)
4554

tests/unit/sqlalchemy/test_datatype_parse.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,17 @@
1616
ARRAY,
1717
INTEGER,
1818
DECIMAL,
19-
DATE,
20-
TIME,
21-
TIMESTAMP,
19+
DATE
2220
)
2321
from sqlalchemy.sql.type_api import TypeEngine
2422

2523
from trino.sqlalchemy import datatype
26-
from trino.sqlalchemy.datatype import MAP, ROW
24+
from trino.sqlalchemy.datatype import (
25+
MAP,
26+
ROW,
27+
TIME,
28+
TIMESTAMP
29+
)
2730

2831

2932
@pytest.mark.parametrize(
@@ -65,8 +68,7 @@ def test_parse_cases(type_str: str, sql_type: TypeEngine, assert_sqltype):
6568
"CHAR(10)": CHAR(10),
6669
"VARCHAR(10)": VARCHAR(10),
6770
"DECIMAL(20)": DECIMAL(20),
68-
"DECIMAL(20, 3)": DECIMAL(20, 3),
69-
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
71+
"DECIMAL(20, 3)": DECIMAL(20, 3)
7072
}
7173

7274

@@ -142,8 +144,8 @@ def test_parse_map(type_str: str, sql_type: ARRAY, assert_sqltype):
142144
),
143145
"row(min timestamp(6) with time zone, max timestamp(6) with time zone)": ROW(
144146
attr_types=[
145-
("min", TIMESTAMP(timezone=True)),
146-
("max", TIMESTAMP(timezone=True)),
147+
("min", TIMESTAMP(6, timezone=True)),
148+
("max", TIMESTAMP(6, timezone=True)),
147149
]
148150
),
149151
'row("first name" varchar, "last name" varchar)': ROW(
@@ -173,12 +175,16 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype):
173175

174176

175177
parse_datetime_testcases = {
176-
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
177178
"date": DATE(),
178179
"time": TIME(),
180+
"time(3)": TIME(3, timezone=False),
181+
"time(6)": TIME(6),
182+
"time(12) with time zone": TIME(12, timezone=True),
179183
"time with time zone": TIME(timezone=True),
180-
"timestamp": TIMESTAMP(),
181-
"timestamp with time zone": TIMESTAMP(timezone=True),
184+
"timestamp(3)": TIMESTAMP(3, timezone=False),
185+
"timestamp(6)": TIMESTAMP(6),
186+
"timestamp(12) with time zone": TIMESTAMP(12, timezone=True),
187+
"timestamp with time zone": TIMESTAMP(timezone=True)
182188
}
183189

184190

@@ -187,6 +193,6 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype):
187193
parse_datetime_testcases.items(),
188194
ids=parse_datetime_testcases.keys(),
189195
)
190-
def test_parse_datetime(type_str: str, sql_type: ARRAY, assert_sqltype):
196+
def test_parse_datetime(type_str: str, sql_type: TypeEngine, assert_sqltype):
191197
actual_type = datatype.parse_sqltype(type_str)
192198
assert_sqltype(actual_type, sql_type)

trino/sqlalchemy/compiler.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,26 @@ def visit_BLOB(self, type_, **kw):
147147
def visit_DATETIME(self, type_, **kw):
148148
return self.visit_TIMESTAMP(type_, **kw)
149149

150+
def visit_TIMESTAMP(self, type_, **kw):
151+
datatype = "TIMESTAMP"
152+
precision = getattr(type_, "precision", None)
153+
if precision:
154+
datatype += f"({precision})"
155+
if getattr(type_, "timezone", False):
156+
datatype += " WITH TIME ZONE"
157+
158+
return datatype
159+
160+
def visit_TIME(self, type_, **kw):
161+
datatype = "TIME"
162+
precision = getattr(type_, "precision", None)
163+
if precision:
164+
datatype += f"({precision})"
165+
if getattr(type_, "timezone", False):
166+
datatype += " WITH TIME ZONE"
167+
168+
return datatype
169+
150170

151171
class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
152172
reserved_words = RESERVED_WORDS

trino/sqlalchemy/datatype.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212
import re
13-
from typing import Iterator, List, Optional, Tuple, Type, Union
13+
from typing import Iterator, List, Optional, Tuple, Type, Union, Dict, Any
1414

1515
from sqlalchemy import util
1616
from sqlalchemy.sql import sqltypes
@@ -55,6 +55,22 @@ def python_type(self):
5555
return list
5656

5757

58+
class TIME(sqltypes.TIME):
59+
__visit_name__ = "TIME"
60+
61+
def __init__(self, precision=None, timezone=False):
62+
super(TIME, self).__init__(timezone=timezone)
63+
self.precision = precision
64+
65+
66+
class TIMESTAMP(sqltypes.TIMESTAMP):
67+
__visit_name__ = "TIMESTAMP"
68+
69+
def __init__(self, precision=None, timezone=False):
70+
super(TIMESTAMP, self).__init__(timezone=timezone)
71+
self.precision = precision
72+
73+
5874
# https://trino.io/docs/current/language/types.html
5975
_type_map = {
6076
# === Boolean ===
@@ -77,8 +93,10 @@ def python_type(self):
7793
"json": sqltypes.JSON,
7894
# === Date and time ===
7995
"date": sqltypes.DATE,
80-
"time": sqltypes.TIME,
81-
"timestamp": sqltypes.TIMESTAMP,
96+
"time": TIME,
97+
"time with time zone": TIME,
98+
"timestamp": TIMESTAMP,
99+
"timestamp with time zone": TIMESTAMP,
82100
# 'interval year to month':
83101
# 'interval day to second':
84102
#
@@ -193,7 +211,10 @@ def parse_sqltype(type_str: str) -> TypeEngine:
193211
type_class = _type_map[type_name]
194212
type_args = [int(o.strip()) for o in type_opts.split(",")] if type_opts else []
195213
if type_name in ("time", "timestamp"):
196-
type_kwargs = dict(timezone=type_str.endswith("with time zone"))
197-
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
214+
type_kwargs: Dict[str, Any] = dict()
215+
if type_str.endswith("with time zone"):
216+
type_kwargs["timezone"] = True
217+
if type_opts is not None:
218+
type_kwargs["precision"] = int(type_opts)
198219
return type_class(**type_kwargs)
199220
return type_class(*type_args)

0 commit comments

Comments
 (0)