1010# See the License for the specific language governing permissions and
1111# limitations under the License.
1212import math
13- from datetime import datetime
13+ from datetime import datetime , time
14+ from decimal import Decimal
1415
1516import pytest
1617import pytz
@@ -123,22 +124,114 @@ def test_string_query_param(trino_connection):
123124 assert rows [0 ][0 ] == "six'"
124125
125126
127+ def test_decimal_query_param (trino_connection ):
128+ cur = trino_connection .cursor ()
129+
130+ cur .execute ("SELECT ?" , params = (Decimal ('0.142857' ),))
131+ rows = cur .fetchall ()
132+
133+ assert rows [0 ][0 ] == Decimal ('0.142857' )
134+
135+
126136def test_datetime_query_param (trino_connection ):
127137 cur = trino_connection .cursor ()
128138
129- cur .execute ("SELECT ?" , params = (datetime (2020 , 1 , 1 , 0 , 0 , 0 ),))
139+ params = datetime (2020 , 1 , 1 , 16 , 43 , 22 , 320000 )
140+
141+ cur .execute ("SELECT ?" , params = (params ,))
130142 rows = cur .fetchall ()
131143
132- assert rows [0 ][0 ] == "2020-01-01 00:00:00.000"
144+ assert rows [0 ][0 ] == params
145+ assert cur .description [0 ][1 ] == "timestamp"
146+
147+
148+ def test_datetime_with_time_zone_query_param (trino_connection ):
149+ cur = trino_connection .cursor ()
150+
151+ params = datetime (2020 , 1 , 1 , 16 , 43 , 22 , 320000 , tzinfo = pytz .timezone ('CET' ))
133152
134153 cur .execute ("SELECT ?" ,
135- params = (datetime ( 2020 , 1 , 1 , 0 , 0 , 0 , tzinfo = pytz . utc ) ,))
154+ params = (params ,))
136155 rows = cur .fetchall ()
137156
138- assert rows [0 ][0 ] == "2020-01-01 00:00:00.000 UTC"
157+ assert rows [0 ][0 ] == params
139158 assert cur .description [0 ][1 ] == "timestamp with time zone"
140159
141160
161+ def test_datetime_with_time_zone_numeric_offset (trino_connection ):
162+ cur = trino_connection .cursor ()
163+
164+ cur .execute ("SELECT TIMESTAMP '2001-08-22 03:04:05.321 -08:00'" )
165+ rows = cur .fetchall ()
166+
167+ assert rows [0 ][0 ] == datetime .strptime ("2001-08-22 03:04:05.321 -08:00" , "%Y-%m-%d %H:%M:%S.%f %z" )
168+
169+
170+ def test_special_datetimes_query_param (trino_connection ):
171+ cur = trino_connection .cursor ()
172+
173+ for special_date in (
174+ datetime .fromtimestamp (1603589478 , pytz .timezone ('Europe/Warsaw' )),
175+ ):
176+ params = special_date
177+
178+ cur .execute ("SELECT ?" , params = (params ,))
179+ rows = cur .fetchall ()
180+
181+ assert rows [0 ][0 ] == params
182+
183+
184+ def test_date_query_param (trino_connection ):
185+ cur = trino_connection .cursor ()
186+
187+ params = datetime (2020 , 1 , 1 , 0 , 0 , 0 ).date ()
188+
189+ cur .execute ("SELECT ?" , params = (params ,))
190+ rows = cur .fetchall ()
191+
192+ assert rows [0 ][0 ] == params
193+
194+
195+ def test_special_dates_query_param (trino_connection ):
196+ cur = trino_connection .cursor ()
197+
198+ for params in (
199+ # datetime(-1, 1, 1, 0, 0, 0).date(),
200+ # datetime(0, 1, 1, 0, 0, 0).date(),
201+ datetime (1752 , 9 , 4 , 0 , 0 , 0 ).date (),
202+ datetime (1970 , 1 , 1 , 0 , 0 , 0 ).date (),
203+ ):
204+ cur .execute ("SELECT ?" , params = (params ,))
205+ rows = cur .fetchall ()
206+
207+ assert rows [0 ][0 ] == params
208+
209+
210+ def test_time_query_param (trino_connection ):
211+ cur = trino_connection .cursor ()
212+
213+ params = time (12 , 3 , 44 , 333000 )
214+
215+ cur .execute ("SELECT ?" , params = (params ,))
216+ rows = cur .fetchall ()
217+
218+ assert rows [0 ][0 ] == params
219+
220+
221+ @pytest .mark .skip (reason = "time with time zone currently not supported" )
222+ def test_time_with_time_zone_query_param (trino_connection ):
223+ cur = trino_connection .cursor ()
224+
225+ params = time (16 , 43 , 22 , 320000 , tzinfo = pytz .timezone ('CET' ))
226+
227+ cur .execute ("SELECT ?" ,
228+ params = (params ,))
229+ rows = cur .fetchall ()
230+
231+ assert rows [0 ][0 ] == params
232+ assert cur .description [0 ][1 ] == "time with time zone"
233+
234+
142235def test_array_query_param (trino_connection ):
143236 cur = trino_connection .cursor ()
144237
@@ -158,6 +251,38 @@ def test_array_query_param(trino_connection):
158251 assert rows [0 ][0 ] == "array(integer)"
159252
160253
254+ def test_array_timestamp_query_param (trino_connection ):
255+ cur = trino_connection .cursor ()
256+
257+ params = [datetime (2020 , 1 , 1 , 0 , 0 , 0 ), datetime (2020 , 1 , 2 , 0 , 0 , 0 )]
258+
259+ cur .execute ("SELECT ?" , params = (params ,))
260+ rows = cur .fetchall ()
261+
262+ assert rows [0 ][0 ] == params
263+
264+ cur .execute ("SELECT TYPEOF(?)" , params = (params ,))
265+ rows = cur .fetchall ()
266+
267+ assert rows [0 ][0 ] == "array(timestamp(6))"
268+
269+
270+ def test_array_timestamp_with_timezone_query_param (trino_connection ):
271+ cur = trino_connection .cursor ()
272+
273+ params = [datetime (2020 , 1 , 1 , 0 , 0 , 0 , tzinfo = pytz .utc ), datetime (2020 , 1 , 2 , 0 , 0 , 0 , tzinfo = pytz .utc )]
274+
275+ cur .execute ("SELECT ?" , params = (params ,))
276+ rows = cur .fetchall ()
277+
278+ assert rows [0 ][0 ] == params
279+
280+ cur .execute ("SELECT TYPEOF(?)" , params = (params ,))
281+ rows = cur .fetchall ()
282+
283+ assert rows [0 ][0 ] == "array(timestamp(6) with time zone)"
284+
285+
161286def test_dict_query_param (trino_connection ):
162287 cur = trino_connection .cursor ()
163288
0 commit comments