Skip to content

Commit 67c104d

Browse files
committed
Support decimal, date, time, timestamp with time zone and timestamp
1 parent 65506e8 commit 67c104d

File tree

3 files changed

+178
-7
lines changed

3 files changed

+178
-7
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 121 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212
import math
13-
from datetime import datetime
13+
from datetime import datetime, time
14+
from decimal import Decimal
1415

1516
import pytest
1617
import pytz
@@ -123,22 +124,105 @@ def test_string_query_param(trino_connection):
123124
assert rows[0][0] == "six'"
124125

125126

127+
def test_decimal_query_param(trino_connection):
    """A ``Decimal`` bound as a query parameter is fetched back as an equal ``Decimal``."""
    expected = Decimal('0.142857')
    cursor = trino_connection.cursor()

    cursor.execute("SELECT ?", params=(expected,))
    fetched = cursor.fetchall()

    assert fetched[0][0] == expected
134+
135+
126136
def test_datetime_query_param(trino_connection):
127137
cur = trino_connection.cursor()
128138

129-
cur.execute("SELECT ?", params=(datetime(2020, 1, 1, 0, 0, 0),))
139+
params = datetime(2020, 1, 1, 16, 43, 22, 320000)
140+
141+
cur.execute("SELECT ?", params=(params,))
130142
rows = cur.fetchall()
131143

132-
assert rows[0][0] == "2020-01-01 00:00:00.000"
144+
assert rows[0][0] == params
145+
assert cur.description[0][1] == "timestamp"
146+
147+
148+
def test_datetime_with_time_zone_query_param(trino_connection):
149+
cur = trino_connection.cursor()
150+
151+
params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('CET'))
133152

134153
cur.execute("SELECT ?",
135-
params=(datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc),))
154+
params=(params,))
136155
rows = cur.fetchall()
137156

138-
assert rows[0][0] == "2020-01-01 00:00:00.000 UTC"
157+
assert rows[0][0] == params
139158
assert cur.description[0][1] == "timestamp with time zone"
140159

141160

161+
def test_special_datetimes_query_param(trino_connection):
    """Each zone-aware datetime listed below must round-trip unchanged as a bound parameter."""
    cursor = trino_connection.cursor()

    special_datetimes = (
        datetime.fromtimestamp(1603589478, pytz.timezone('Europe/Warsaw')),
    )

    for expected in special_datetimes:
        cursor.execute("SELECT ?", params=(expected,))
        fetched = cursor.fetchall()

        assert fetched[0][0] == expected
173+
174+
175+
def test_date_query_param(trino_connection):
    """A ``date`` bound as a query parameter is fetched back as an equal ``date``."""
    expected = datetime(2020, 1, 1, 0, 0, 0).date()
    cursor = trino_connection.cursor()

    cursor.execute("SELECT ?", params=(expected,))
    fetched = cursor.fetchall()

    assert fetched[0][0] == expected
184+
185+
186+
def test_special_dates_query_param(trino_connection):
    """Each date listed below must round-trip unchanged as a bound parameter."""
    cursor = trino_connection.cursor()

    special_dates = (
        # datetime(-1, 1, 1, 0, 0, 0).date(),
        # datetime(0, 1, 1, 0, 0, 0).date(),
        datetime(1752, 9, 4, 0, 0, 0).date(),
        datetime(1970, 1, 1, 0, 0, 0).date(),
    )

    for expected in special_dates:
        cursor.execute("SELECT ?", params=(expected,))
        fetched = cursor.fetchall()

        assert fetched[0][0] == expected
199+
200+
201+
def test_time_query_param(trino_connection):
    """A naive ``time`` bound as a query parameter is fetched back as an equal ``time``."""
    expected = time(12, 3, 44, 333000)
    cursor = trino_connection.cursor()

    cursor.execute("SELECT ?", params=(expected,))
    fetched = cursor.fetchall()

    assert fetched[0][0] == expected
210+
211+
212+
@pytest.mark.skip(reason="time with time zone currently not supported")
def test_time_with_time_zone_query_param(trino_connection):
    """Round-trip a zone-aware ``time`` and check the reported column type.

    Skipped for now; see the skip reason on the decorator.
    """
    expected = time(16, 43, 22, 320000, tzinfo=pytz.timezone('CET'))
    cursor = trino_connection.cursor()

    cursor.execute("SELECT ?", params=(expected,))
    fetched = cursor.fetchall()

    assert fetched[0][0] == expected
    assert cursor.description[0][1] == "time with time zone"
224+
225+
142226
def test_array_query_param(trino_connection):
143227
cur = trino_connection.cursor()
144228

@@ -158,6 +242,38 @@ def test_array_query_param(trino_connection):
158242
assert rows[0][0] == "array(integer)"
159243

160244

245+
def test_array_timestamp_query_param(trino_connection):
    """A list of naive datetimes round-trips and is typed ``array(timestamp(6))``."""
    expected = [datetime(2020, 1, 1, 0, 0, 0), datetime(2020, 1, 2, 0, 0, 0)]
    cursor = trino_connection.cursor()

    # First check the values themselves.
    cursor.execute("SELECT ?", params=(expected,))
    fetched = cursor.fetchall()
    assert fetched[0][0] == expected

    # Then check the SQL type the server inferred for the parameter.
    cursor.execute("SELECT TYPEOF(?)", params=(expected,))
    fetched = cursor.fetchall()
    assert fetched[0][0] == "array(timestamp(6))"
259+
260+
261+
def test_array_timestamp_with_timezone_query_param(trino_connection):
    """A list of UTC datetimes round-trips and is typed ``array(timestamp(6) with time zone)``."""
    expected = [
        datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
        datetime(2020, 1, 2, 0, 0, 0, tzinfo=pytz.utc),
    ]
    cursor = trino_connection.cursor()

    # First check the values themselves.
    cursor.execute("SELECT ?", params=(expected,))
    fetched = cursor.fetchall()
    assert fetched[0][0] == expected

    # Then check the SQL type the server inferred for the parameter.
    cursor.execute("SELECT TYPEOF(?)", params=(expected,))
    fetched = cursor.fetchall()
    assert fetched[0][0] == "array(timestamp(6) with time zone)"
275+
276+
161277
def test_dict_query_param(trino_connection):
162278
cur = trino_connection.cursor()
163279

trino/client.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
import copy
3737
import os
3838
import re
39+
from decimal import Decimal
40+
from datetime import datetime
41+
import pytz
3942
from typing import Any, Dict, List, Optional, Tuple, Union
4043
import urllib.parse
4144

@@ -494,12 +497,44 @@ def __iter__(self):
494497
for row in rows:
495498
self._rownumber += 1
496499
logger.debug("row %s", row)
497-
yield row
500+
yield self._map_to_python_types(row, self._query.columns)
498501

499502
@property
def response_headers(self):
    """Return the ``response_headers`` attribute of the wrapped query object."""
    query = self._query
    return query.response_headers
502505

506+
@classmethod
def _map_to_python_type(cls, item: Tuple[Any, Dict]) -> Any:
    """Convert one raw result value into a richer Python object.

    Args:
        item: A ``(value, column)`` pair. ``column`` must expose the type
            name at ``column["typeSignature"]["rawType"]``; for arrays the
            element type is read from
            ``column["typeSignature"]["arguments"][0]["value"]``.

    Returns:
        ``None`` for a NULL value, a ``Decimal`` for decimals, a ``date``,
        ``time`` or ``datetime`` for the temporal types, a list of
        converted elements for arrays, and the value unchanged for every
        other type.

    Raises:
        ValueError: If a temporal value does not match the expected text
            format.
    """
    (value, data_type) = item

    # NULL has no textual form to parse, whatever the column type is.
    if value is None:
        return None

    raw_type = data_type["typeSignature"]["rawType"]

    # Only real arrays are unpacked element by element. Other list-shaped
    # values (e.g. rows) have a different type-argument layout and are
    # passed through untouched.
    if raw_type == "array" and isinstance(value, list):
        element_type = {
            "typeSignature": data_type["typeSignature"]["arguments"][0]["value"]
        }
        return [cls._map_to_python_type((array_item, element_type)) for array_item in value]

    if "decimal" in raw_type:
        return Decimal(value)

    if raw_type == "date":
        return datetime.strptime(value, "%Y-%m-%d").date()

    if raw_type == "timestamp with time zone":
        # TODO handle timezone numeric offset and add tests
        dt, tz = value.rsplit(' ', 1)
        naive = datetime.strptime(dt, "%Y-%m-%d %H:%M:%S.%f")
        # pytz zones must be attached with localize(); passing them via
        # replace(tzinfo=...) picks the zone's first historical offset.
        return pytz.timezone(tz).localize(naive)

    # NOTE(review): "%f" needs a fractional part of one to six digits;
    # confirm which precisions the server can actually send.
    if "timestamp" in raw_type:
        return datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f")

    if raw_type == "time with time zone":
        # TODO: not supported yet (needs numeric-offset handling and
        # tests). Hand back the raw string instead of failing in strptime.
        return value

    if "time" in raw_type:
        return datetime.strptime(value, "%H:%M:%S.%f").time()

    return value
534+
535+
def _map_to_python_types(self, row: List[Any], columns: List[Dict[str, Any]]) -> List[Any]:
    """Convert each value in ``row`` using the column at the same position in ``columns``."""
    convert = self._map_to_python_type
    return [convert(pair) for pair in zip(row, columns)]
537+
503538

504539
class TrinoQuery(object):
505540
"""Represent the execution of a SQL statement by Trino."""

trino/dbapi.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Fetch methods return rows as a list of lists on purpose to let the caller
1818
decide to convert them to a list of tuples.
1919
"""
20-
20+
from decimal import Decimal
2121
from typing import Any, List, Optional # NOQA for mypy types
2222

2323
import copy
@@ -360,11 +360,28 @@ def _format_prepared_param(self, param):
360360
return "X'%s'" % param.hex()
361361

362362
if isinstance(param, datetime.datetime):
363+
# TODO: support numeric offset and add tests
363364
datetime_str = param.strftime("%Y-%m-%d %H:%M:%S.%f %Z")
364365
# strip trailing whitespace if param has no zone
365366
datetime_str = datetime_str.rstrip(" ")
366367
return "TIMESTAMP '%s'" % datetime_str
367368

369+
if isinstance(param, datetime.time):
370+
time_str = param.strftime("%H:%M:%S.%f")
371+
return "TIME '%s'" % time_str
372+
373+
# # TODO: how to distinguish between time and time with time zone in Python?
374+
# if isinstance(param, datetime.time):
375+
# time_str = param.strftime("%H:%M:%S.%f")[:-3]
376+
# # actually we lost the datetime here, so timezone is meaningless...
377+
# utc_offset = param.tzinfo.utcoffset(param)
378+
# time_tz = ":".join(map(lambda x: x.rjust(2, '0'), str(utc_offset).split(":")[:-1])).rjust(6, '+')
379+
# return "TIME '%s %s'" % (time_str, time_tz)
380+
381+
if isinstance(param, datetime.date):
382+
date_str = param.strftime("%Y-%m-%d")
383+
return "DATE '%s'" % date_str
384+
368385
if isinstance(param, list):
369386
return "ARRAY[%s]" % ','.join(map(self._format_prepared_param, param))
370387

@@ -379,6 +396,9 @@ def _format_prepared_param(self, param):
379396
if isinstance(param, uuid.UUID):
380397
return "UUID '%s'" % param
381398

399+
if isinstance(param, Decimal):
400+
return "DECIMAL '%s'" % param
401+
382402
raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." % type(param))
383403

384404
def _deallocate_prepare_statement(self, added_prepare_header, statement_name):

0 commit comments

Comments
 (0)