Skip to content

Commit 7972aab

Browse files
committed
Add support for TIMEZONE
1 parent f97aea6 commit 7972aab

File tree

8 files changed

+106
-2
lines changed

8 files changed

+106
-2
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ repos:
1313
additional_dependencies:
1414
- "types-pytz"
1515
- "types-requests"
16+
- "types-tzlocal"

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,20 @@ conn = trino.dbapi.connect(
357357
)
358358
```
359359

360+
## Timezone
361+
362+
Sets the time zone for the session using the time zone name. Defaults to the timezone set on your workstation.
363+
364+
```python
365+
import trino
366+
conn = trino.dbapi.connect(
367+
host='localhost',
368+
port=443,
369+
user='username',
370+
timezone="Europe/Brussels",
371+
)
372+
```
373+
360374
## SSL
361375

362376
### SSL verification

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
"Topic :: Database :: Front-Ends",
7777
],
7878
python_requires='>=3.7',
79-
install_requires=["pytz", "requests"],
79+
install_requires=["pytz", "requests", "tzlocal"],
8080
extras_require={
8181
"all": all_require,
8282
"kerberos": kerberos_require,

tests/integration/test_dbapi_integration.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import pytest
1717
import pytz
1818
import requests
19+
from tzlocal import get_localzone_name
1920

2021
import trino
2122
from tests.integration.conftest import trino_version
@@ -1120,3 +1121,27 @@ def test_prepared_statements(run_trino):
11201121
cur.execute('DEALLOCATE PREPARE test_prepared_statements')
11211122
cur.fetchall()
11221123
assert cur._request._client_session.prepared_statements == {}
1124+
1125+
1126+
def test_set_timezone_in_connection(run_trino):
1127+
_, host, port = run_trino
1128+
1129+
trino_connection = trino.dbapi.Connection(
1130+
host=host, port=port, user="test", catalog="tpch", timezone="Europe/Brussels"
1131+
)
1132+
cur = trino_connection.cursor()
1133+
cur.execute('SELECT current_timezone()')
1134+
res = cur.fetchall()
1135+
assert res[0][0] == "Europe/Brussels"
1136+
1137+
1138+
def test_connection_without_timezone(run_trino):
1139+
_, host, port = run_trino
1140+
1141+
trino_connection = trino.dbapi.Connection(
1142+
host=host, port=port, user="test", catalog="tpch"
1143+
)
1144+
cur = trino_connection.cursor()
1145+
cur.execute('SELECT current_timezone()')
1146+
res = cur.fetchall()
1147+
assert res[0][0] == get_localzone_name()

tests/unit/test_client.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import threading
1414
import time
1515
import uuid
16+
from tzlocal import get_localzone_name
1617
from typing import Optional, Dict
1718
from unittest import mock
1819
from urllib.parse import urlparse
@@ -65,6 +66,7 @@ def test_request_headers(mock_get_and_post):
6566
schema = "test_schema"
6667
user = "test_user"
6768
source = "test_source"
69+
timezone = "Europe/Brussels"
6870
accept_encoding_header = "accept-encoding"
6971
accept_encoding_value = "identity,deflate,gzip"
7072
client_info_header = constants.HEADER_CLIENT_INFO
@@ -78,6 +80,7 @@ def test_request_headers(mock_get_and_post):
7880
source=source,
7981
catalog=catalog,
8082
schema=schema,
83+
timezone=timezone,
8184
headers={
8285
accept_encoding_header: accept_encoding_value,
8386
client_info_header: client_info_value,
@@ -93,9 +96,10 @@ def assert_headers(headers):
9396
assert headers[constants.HEADER_SOURCE] == source
9497
assert headers[constants.HEADER_USER] == user
9598
assert headers[constants.HEADER_SESSION] == ""
99+
assert headers[constants.HEADER_TIMEZONE] == timezone
96100
assert headers[accept_encoding_header] == accept_encoding_value
97101
assert headers[client_info_header] == client_info_value
98-
assert len(headers.keys()) == 8
102+
assert len(headers.keys()) == 9
99103

100104
req.post("URL")
101105
_, post_kwargs = post.call_args
@@ -1056,3 +1060,49 @@ def test_request_headers_role_empty(mock_get_and_post):
10561060
req.get("URL")
10571061
_, get_kwargs = get.call_args
10581062
assert_headers_with_roles(post_kwargs["headers"], None)
1063+
1064+
1065+
def assert_headers_timezone(headers: Dict[str, str], timezone: str):
1066+
assert headers[constants.HEADER_TIMEZONE] == timezone
1067+
1068+
1069+
def test_request_headers_with_timezone(mock_get_and_post):
1070+
get, post = mock_get_and_post
1071+
1072+
req = TrinoRequest(
1073+
host="coordinator",
1074+
port=8080,
1075+
client_session=ClientSession(
1076+
user="test_user",
1077+
timezone="Europe/Brussels"
1078+
),
1079+
)
1080+
1081+
req.post("URL")
1082+
_, post_kwargs = post.call_args
1083+
assert_headers_timezone(post_kwargs["headers"], "Europe/Brussels")
1084+
1085+
req.get("URL")
1086+
_, get_kwargs = get.call_args
1087+
assert_headers_timezone(post_kwargs["headers"], "Europe/Brussels")
1088+
1089+
1090+
def test_request_headers_without_timezone(mock_get_and_post):
1091+
get, post = mock_get_and_post
1092+
1093+
req = TrinoRequest(
1094+
host="coordinator",
1095+
port=8080,
1096+
client_session=ClientSession(
1097+
user="test_user",
1098+
),
1099+
)
1100+
localzone = get_localzone_name()
1101+
1102+
req.post("URL")
1103+
_, post_kwargs = post.call_args
1104+
assert_headers_timezone(post_kwargs["headers"], localzone)
1105+
1106+
req.get("URL")
1107+
_, get_kwargs = get.call_args
1108+
assert_headers_timezone(post_kwargs["headers"], localzone)

trino/client.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import urllib.parse
4444
from datetime import datetime, timedelta, timezone
4545
from decimal import Decimal
46+
from tzlocal import get_localzone_name
4647
from typing import Any, Dict, List, Optional, Tuple, Union
4748

4849
import pytz
@@ -100,6 +101,8 @@ class ClientSession(object):
100101
:param client_tags: Client tags as list of strings.
101102
:param roles: roles for the current session. Some connectors do not
102103
support role management. See connector documentation for more details.
104+
:param timezone: The timezone for query processing. Defaults to the timezone
105+
of the Trino cluster, and not the timezone of the client.
103106
"""
104107

105108
def __init__(
@@ -114,6 +117,7 @@ def __init__(
114117
extra_credential: List[Tuple[str, str]] = None,
115118
client_tags: List[str] = None,
116119
roles: Dict[str, str] = None,
120+
timezone: str = None,
117121
):
118122
self._user = user
119123
self._catalog = catalog
@@ -127,6 +131,7 @@ def __init__(
127131
self._roles = roles.copy() if roles is not None else {}
128132
self._prepared_statements: Dict[str, str] = {}
129133
self._object_lock = threading.Lock()
134+
self._timezone = timezone or get_localzone_name()
130135

131136
@property
132137
def user(self):
@@ -207,6 +212,11 @@ def prepared_statements(self, prepared_statements):
207212
with self._object_lock:
208213
self._prepared_statements = prepared_statements
209214

215+
@property
216+
def timezone(self):
217+
with self._object_lock:
218+
return self._timezone
219+
210220
def __getstate__(self):
211221
state = self.__dict__.copy()
212222
del state["_object_lock"]
@@ -408,6 +418,7 @@ def http_headers(self) -> Dict[str, str]:
408418
headers[constants.HEADER_SCHEMA] = self._client_session.schema
409419
headers[constants.HEADER_SOURCE] = self._client_session.source
410420
headers[constants.HEADER_USER] = self._client_session.user
421+
headers[constants.HEADER_TIMEZONE] = self._client_session.timezone
411422
if len(self._client_session.roles.values()):
412423
headers[constants.HEADER_ROLE] = ",".join(
413424
# ``name`` must not contain ``=``

trino/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
HEADER_CLIENT_INFO = "X-Trino-Client-Info"
3535
HEADER_CLIENT_TAGS = "X-Trino-Client-Tags"
3636
HEADER_EXTRA_CREDENTIAL = "X-Trino-Extra-Credential"
37+
HEADER_TIMEZONE = "X-Trino-Time-Zone"
3738

3839
HEADER_SESSION = "X-Trino-Session"
3940
HEADER_SET_SESSION = "X-Trino-Set-Session"

trino/dbapi.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(
111111
client_tags=None,
112112
experimental_python_types=False,
113113
roles=None,
114+
timezone=None,
114115
):
115116
self.host = host
116117
self.port = port
@@ -130,6 +131,7 @@ def __init__(
130131
extra_credential=extra_credential,
131132
client_tags=client_tags,
132133
roles=roles,
134+
timezone=timezone,
133135
)
134136
# mypy cannot follow module import
135137
if http_session is None:

0 commit comments

Comments
 (0)