1313import threading
1414import time
1515import uuid
16+ from tzlocal import get_localzone_name
1617from typing import Optional , Dict
1718from unittest import mock
1819from urllib .parse import urlparse
3233from trino .client import TrinoQuery , TrinoRequest , TrinoResult , ClientSession , _DelayExponential , _retry_with , \
3334 _RetryWithExponentialBackoff
3435
36+ try :
37+ from zoneinfo ._common import ZoneInfoNotFoundError # type: ignore
38+
39+ except ModuleNotFoundError :
40+ from backports .zoneinfo ._common import ZoneInfoNotFoundError # type: ignore
41+
3542
3643@mock .patch ("trino.client.TrinoRequest.http" )
3744def test_trino_initial_request (mock_requests , sample_post_response_data ):
@@ -65,6 +72,7 @@ def test_request_headers(mock_get_and_post):
6572 schema = "test_schema"
6673 user = "test_user"
6774 source = "test_source"
75+ timezone = "Europe/Brussels"
6876 accept_encoding_header = "accept-encoding"
6977 accept_encoding_value = "identity,deflate,gzip"
7078 client_info_header = constants .HEADER_CLIENT_INFO
@@ -78,6 +86,7 @@ def test_request_headers(mock_get_and_post):
7886 source = source ,
7987 catalog = catalog ,
8088 schema = schema ,
89+ timezone = timezone ,
8190 headers = {
8291 accept_encoding_header : accept_encoding_value ,
8392 client_info_header : client_info_value ,
@@ -93,9 +102,10 @@ def assert_headers(headers):
93102 assert headers [constants .HEADER_SOURCE ] == source
94103 assert headers [constants .HEADER_USER ] == user
95104 assert headers [constants .HEADER_SESSION ] == ""
105+ assert headers [constants .HEADER_TIMEZONE ] == timezone
96106 assert headers [accept_encoding_header ] == accept_encoding_value
97107 assert headers [client_info_header ] == client_info_value
98- assert len (headers .keys ()) == 8
108+ assert len (headers .keys ()) == 9
99109
100110 req .post ("URL" )
101111 _ , post_kwargs = post .call_args
@@ -1056,3 +1066,62 @@ def test_request_headers_role_empty(mock_get_and_post):
10561066 req .get ("URL" )
10571067 _ , get_kwargs = get .call_args
10581068 assert_headers_with_roles (post_kwargs ["headers" ], None )
1069+
1070+
1071+ def assert_headers_timezone (headers : Dict [str , str ], timezone : str ):
1072+ assert headers [constants .HEADER_TIMEZONE ] == timezone
1073+
1074+
1075+ def test_request_headers_with_timezone (mock_get_and_post ):
1076+ get , post = mock_get_and_post
1077+
1078+ req = TrinoRequest (
1079+ host = "coordinator" ,
1080+ port = 8080 ,
1081+ client_session = ClientSession (
1082+ user = "test_user" ,
1083+ timezone = "Europe/Brussels"
1084+ ),
1085+ )
1086+
1087+ req .post ("URL" )
1088+ _ , post_kwargs = post .call_args
1089+ assert_headers_timezone (post_kwargs ["headers" ], "Europe/Brussels" )
1090+
1091+ req .get ("URL" )
1092+ _ , get_kwargs = get .call_args
1093+ assert_headers_timezone (post_kwargs ["headers" ], "Europe/Brussels" )
1094+
1095+
1096+ def test_request_headers_without_timezone (mock_get_and_post ):
1097+ get , post = mock_get_and_post
1098+
1099+ req = TrinoRequest (
1100+ host = "coordinator" ,
1101+ port = 8080 ,
1102+ client_session = ClientSession (
1103+ user = "test_user" ,
1104+ ),
1105+ )
1106+ localzone = get_localzone_name ()
1107+
1108+ req .post ("URL" )
1109+ _ , post_kwargs = post .call_args
1110+ assert_headers_timezone (post_kwargs ["headers" ], localzone )
1111+
1112+ req .get ("URL" )
1113+ _ , get_kwargs = get .call_args
1114+ assert_headers_timezone (post_kwargs ["headers" ], localzone )
1115+
1116+
1117+ def test_request_with_invalid_timezone (mock_get_and_post ):
1118+ with pytest .raises (ZoneInfoNotFoundError ) as zinfo_error :
1119+ TrinoRequest (
1120+ host = "coordinator" ,
1121+ port = 8080 ,
1122+ client_session = ClientSession (
1123+ user = "test_user" ,
1124+ timezone = "INVALID_TIMEZONE"
1125+ ),
1126+ )
1127+ assert str (zinfo_error .value ).startswith ("'No time zone found with key" )
0 commit comments