Skip to content

Commit d751cb4

Browse files
committed
Add support for TIMEZONE
1 parent f97aea6 commit d751cb4

File tree

7 files changed

+82
-2
lines changed

7 files changed

+82
-2
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ repos:
1313
additional_dependencies:
1414
- "types-pytz"
1515
- "types-requests"
16+
- "types-tzlocal"

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
"pytest",
4040
"pytest-runner",
4141
"click",
42+
"tzlocal",
4243
]
4344

4445
setup(
@@ -76,7 +77,7 @@
7677
"Topic :: Database :: Front-Ends",
7778
],
7879
python_requires='>=3.7',
79-
install_requires=["pytz", "requests"],
80+
install_requires=["pytz", "requests", "tzlocal"],
8081
extras_require={
8182
"all": all_require,
8283
"kerberos": kerberos_require,

tests/integration/test_dbapi_integration.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,3 +1120,15 @@ def test_prepared_statements(run_trino):
11201120
cur.execute('DEALLOCATE PREPARE test_prepared_statements')
11211121
cur.fetchall()
11221122
assert cur._request._client_session.prepared_statements == {}
1123+
1124+
1125+
def test_set_timezone_in_connection(run_trino):
1126+
_, host, port = run_trino
1127+
1128+
trino_connection = trino.dbapi.Connection(
1129+
host=host, port=port, user="test", catalog="tpch", timezone="Europe/Brussels"
1130+
)
1131+
cur = trino_connection.cursor()
1132+
cur.execute('SHOW TABLES FROM information_schema')
1133+
cur.fetchall()
1134+
assert cur._request.http_headers[constants.HEADER_TIMEZONE] == "Europe/Brussels"

tests/unit/test_client.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import threading
1414
import time
1515
import uuid
16+
from tzlocal import get_localzone_name
1617
from typing import Optional, Dict
1718
from unittest import mock
1819
from urllib.parse import urlparse
@@ -65,6 +66,7 @@ def test_request_headers(mock_get_and_post):
6566
schema = "test_schema"
6667
user = "test_user"
6768
source = "test_source"
69+
timezone = "Europe/Brussels"
6870
accept_encoding_header = "accept-encoding"
6971
accept_encoding_value = "identity,deflate,gzip"
7072
client_info_header = constants.HEADER_CLIENT_INFO
@@ -78,6 +80,7 @@ def test_request_headers(mock_get_and_post):
7880
source=source,
7981
catalog=catalog,
8082
schema=schema,
83+
timezone=timezone,
8184
headers={
8285
accept_encoding_header: accept_encoding_value,
8386
client_info_header: client_info_value,
@@ -93,9 +96,10 @@ def assert_headers(headers):
9396
assert headers[constants.HEADER_SOURCE] == source
9497
assert headers[constants.HEADER_USER] == user
9598
assert headers[constants.HEADER_SESSION] == ""
99+
assert headers[constants.HEADER_TIMEZONE] == timezone
96100
assert headers[accept_encoding_header] == accept_encoding_value
97101
assert headers[client_info_header] == client_info_value
98-
assert len(headers.keys()) == 8
102+
assert len(headers.keys()) == 9
99103

100104
req.post("URL")
101105
_, post_kwargs = post.call_args
@@ -1056,3 +1060,51 @@ def test_request_headers_role_empty(mock_get_and_post):
10561060
req.get("URL")
10571061
_, get_kwargs = get.call_args
10581062
assert_headers_with_roles(post_kwargs["headers"], None)
1063+
1064+
1065+
def assert_headers_timezone(headers: Dict[str, str], timezone: Optional[str]):
1066+
if timezone is None:
1067+
assert headers[constants.HEADER_TIMEZONE] == get_localzone_name()
1068+
else:
1069+
assert headers[constants.HEADER_TIMEZONE] == timezone
1070+
1071+
1072+
def test_request_headers_with_timezone(mock_get_and_post):
1073+
get, post = mock_get_and_post
1074+
1075+
req = TrinoRequest(
1076+
host="coordinator",
1077+
port=8080,
1078+
client_session=ClientSession(
1079+
user="test_user",
1080+
timezone="Europe/Brussels"
1081+
),
1082+
)
1083+
1084+
req.post("URL")
1085+
_, post_kwargs = post.call_args
1086+
assert_headers_timezone(post_kwargs["headers"], "Europe/Brussels")
1087+
1088+
req.get("URL")
1089+
_, get_kwargs = get.call_args
1090+
assert_headers_timezone(post_kwargs["headers"], "Europe/Brussels")
1091+
1092+
1093+
def test_request_headers_without_timezone(mock_get_and_post):
1094+
get, post = mock_get_and_post
1095+
1096+
req = TrinoRequest(
1097+
host="coordinator",
1098+
port=8080,
1099+
client_session=ClientSession(
1100+
user="test_user",
1101+
),
1102+
)
1103+
1104+
req.post("URL")
1105+
_, post_kwargs = post.call_args
1106+
assert_headers_timezone(post_kwargs["headers"], None)
1107+
1108+
req.get("URL")
1109+
_, get_kwargs = get.call_args
1110+
assert_headers_timezone(post_kwargs["headers"], None)

trino/client.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import urllib.parse
4444
from datetime import datetime, timedelta, timezone
4545
from decimal import Decimal
46+
from tzlocal import get_localzone_name
4647
from typing import Any, Dict, List, Optional, Tuple, Union
4748

4849
import pytz
@@ -100,6 +101,8 @@ class ClientSession(object):
100101
:param client_tags: Client tags as list of strings.
101102
:param roles: roles for the current session. Some connectors do not
102103
support role management. See connector documentation for more details.
104+
:param timezone: The timezone for query processing. Defaults to the timezone
105+
of the Trino cluster, and not the timezone of the client.
103106
"""
104107

105108
def __init__(
@@ -114,6 +117,7 @@ def __init__(
114117
extra_credential: List[Tuple[str, str]] = None,
115118
client_tags: List[str] = None,
116119
roles: Dict[str, str] = None,
120+
timezone: str = None,
117121
):
118122
self._user = user
119123
self._catalog = catalog
@@ -127,6 +131,7 @@ def __init__(
127131
self._roles = roles.copy() if roles is not None else {}
128132
self._prepared_statements: Dict[str, str] = {}
129133
self._object_lock = threading.Lock()
134+
self._timezone = timezone or get_localzone_name()
130135

131136
@property
132137
def user(self):
@@ -207,6 +212,11 @@ def prepared_statements(self, prepared_statements):
207212
with self._object_lock:
208213
self._prepared_statements = prepared_statements
209214

215+
@property
216+
def timezone(self):
217+
with self._object_lock:
218+
return self._timezone
219+
210220
def __getstate__(self):
211221
state = self.__dict__.copy()
212222
del state["_object_lock"]
@@ -408,6 +418,7 @@ def http_headers(self) -> Dict[str, str]:
408418
headers[constants.HEADER_SCHEMA] = self._client_session.schema
409419
headers[constants.HEADER_SOURCE] = self._client_session.source
410420
headers[constants.HEADER_USER] = self._client_session.user
421+
headers[constants.HEADER_TIMEZONE] = self._client_session.timezone
411422
if len(self._client_session.roles.values()):
412423
headers[constants.HEADER_ROLE] = ",".join(
413424
# ``name`` must not contain ``=``

trino/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
HEADER_CLIENT_INFO = "X-Trino-Client-Info"
3535
HEADER_CLIENT_TAGS = "X-Trino-Client-Tags"
3636
HEADER_EXTRA_CREDENTIAL = "X-Trino-Extra-Credential"
37+
HEADER_TIMEZONE = "X-Trino-Time-Zone"
3738

3839
HEADER_SESSION = "X-Trino-Session"
3940
HEADER_SET_SESSION = "X-Trino-Set-Session"

trino/dbapi.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(
111111
client_tags=None,
112112
experimental_python_types=False,
113113
roles=None,
114+
timezone=None,
114115
):
115116
self.host = host
116117
self.port = port
@@ -130,6 +131,7 @@ def __init__(
130131
extra_credential=extra_credential,
131132
client_tags=client_tags,
132133
roles=roles,
134+
timezone=timezone,
133135
)
134136
# mypy cannot follow module import
135137
if http_session is None:

0 commit comments

Comments
 (0)