1010# See the License for the specific language governing permissions and
1111# limitations under the License.
1212import math
13- from datetime import datetime
13+ from datetime import datetime , time
14+ from decimal import Decimal
1415
1516import pytest
1617import pytz
@@ -123,22 +124,105 @@ def test_string_query_param(trino_connection):
123124 assert rows [0 ][0 ] == "six'"
124125
125126
127+ def test_decimal_query_param (trino_connection ):
128+ cur = trino_connection .cursor ()
129+
130+ cur .execute ("SELECT ?" , params = (Decimal ('0.142857' ),))
131+ rows = cur .fetchall ()
132+
133+ assert rows [0 ][0 ] == Decimal ('0.142857' )
134+
135+
126136def test_datetime_query_param (trino_connection ):
127137 cur = trino_connection .cursor ()
128138
129- cur .execute ("SELECT ?" , params = (datetime (2020 , 1 , 1 , 0 , 0 , 0 ),))
139+ params = datetime (2020 , 1 , 1 , 16 , 43 , 22 , 320000 )
140+
141+ cur .execute ("SELECT ?" , params = (params ,))
130142 rows = cur .fetchall ()
131143
132- assert rows [0 ][0 ] == "2020-01-01 00:00:00.000"
144+ assert rows [0 ][0 ] == params
145+ assert cur .description [0 ][1 ] == "timestamp"
146+
147+
148+ def test_datetime_with_time_zone_query_param (trino_connection ):
149+ cur = trino_connection .cursor ()
150+
151+ params = datetime (2020 , 1 , 1 , 16 , 43 , 22 , 320000 , tzinfo = pytz .timezone ('CET' ))
133152
134153 cur .execute ("SELECT ?" ,
135- params = (datetime ( 2020 , 1 , 1 , 0 , 0 , 0 , tzinfo = pytz . utc ) ,))
154+ params = (params ,))
136155 rows = cur .fetchall ()
137156
138- assert rows [0 ][0 ] == "2020-01-01 00:00:00.000 UTC"
157+ assert rows [0 ][0 ] == params
139158 assert cur .description [0 ][1 ] == "timestamp with time zone"
140159
141160
161+ def test_special_datetimes_query_param (trino_connection ):
162+ cur = trino_connection .cursor ()
163+
164+ for special_date in (
165+ datetime .fromtimestamp (1603589478 , pytz .timezone ('Europe/Warsaw' )),
166+ ):
167+ params = special_date
168+
169+ cur .execute ("SELECT ?" , params = (params ,))
170+ rows = cur .fetchall ()
171+
172+ assert rows [0 ][0 ] == params
173+
174+
175+ def test_date_query_param (trino_connection ):
176+ cur = trino_connection .cursor ()
177+
178+ params = datetime (2020 , 1 , 1 , 0 , 0 , 0 ).date ()
179+
180+ cur .execute ("SELECT ?" , params = (params ,))
181+ rows = cur .fetchall ()
182+
183+ assert rows [0 ][0 ] == params
184+
185+
186+ def test_special_dates_query_param (trino_connection ):
187+ cur = trino_connection .cursor ()
188+
189+ for params in (
190+ # datetime(-1, 1, 1, 0, 0, 0).date(),
191+ # datetime(0, 1, 1, 0, 0, 0).date(),
192+ datetime (1752 , 9 , 4 , 0 , 0 , 0 ).date (),
193+ datetime (1970 , 1 , 1 , 0 , 0 , 0 ).date (),
194+ ):
195+ cur .execute ("SELECT ?" , params = (params ,))
196+ rows = cur .fetchall ()
197+
198+ assert rows [0 ][0 ] == params
199+
200+
201+ def test_time_query_param (trino_connection ):
202+ cur = trino_connection .cursor ()
203+
204+ params = time (12 , 3 , 44 , 333000 )
205+
206+ cur .execute ("SELECT ?" , params = (params ,))
207+ rows = cur .fetchall ()
208+
209+ assert rows [0 ][0 ] == params
210+
211+
212+ @pytest .mark .skip (reason = "time with time zone currently not supported" )
213+ def test_time_with_time_zone_query_param (trino_connection ):
214+ cur = trino_connection .cursor ()
215+
216+ params = time (16 , 43 , 22 , 320000 , tzinfo = pytz .timezone ('CET' ))
217+
218+ cur .execute ("SELECT ?" ,
219+ params = (params ,))
220+ rows = cur .fetchall ()
221+
222+ assert rows [0 ][0 ] == params
223+ assert cur .description [0 ][1 ] == "time with time zone"
224+
225+
142226def test_array_query_param (trino_connection ):
143227 cur = trino_connection .cursor ()
144228
@@ -158,6 +242,38 @@ def test_array_query_param(trino_connection):
158242 assert rows [0 ][0 ] == "array(integer)"
159243
160244
245+ def test_array_timestamp_query_param (trino_connection ):
246+ cur = trino_connection .cursor ()
247+
248+ params = [datetime (2020 , 1 , 1 , 0 , 0 , 0 ), datetime (2020 , 1 , 2 , 0 , 0 , 0 )]
249+
250+ cur .execute ("SELECT ?" , params = (params ,))
251+ rows = cur .fetchall ()
252+
253+ assert rows [0 ][0 ] == params
254+
255+ cur .execute ("SELECT TYPEOF(?)" , params = (params ,))
256+ rows = cur .fetchall ()
257+
258+ assert rows [0 ][0 ] == "array(timestamp(6))"
259+
260+
261+ def test_array_timestamp_with_timezone_query_param (trino_connection ):
262+ cur = trino_connection .cursor ()
263+
264+ params = [datetime (2020 , 1 , 1 , 0 , 0 , 0 , tzinfo = pytz .utc ), datetime (2020 , 1 , 2 , 0 , 0 , 0 , tzinfo = pytz .utc )]
265+
266+ cur .execute ("SELECT ?" , params = (params ,))
267+ rows = cur .fetchall ()
268+
269+ assert rows [0 ][0 ] == params
270+
271+ cur .execute ("SELECT TYPEOF(?)" , params = (params ,))
272+ rows = cur .fetchall ()
273+
274+ assert rows [0 ][0 ] == "array(timestamp(6) with time zone)"
275+
276+
161277def test_dict_query_param (trino_connection ):
162278 cur = trino_connection .cursor ()
163279
0 commit comments