1111# limitations under the License.
1212import math
1313from datetime import datetime
14+ from decimal import Decimal
1415
1516import pytest
1617import pytz
@@ -123,19 +124,37 @@ def test_string_query_param(trino_connection):
123124 assert rows [0 ][0 ] == "six'"
124125
125126
127+ def test_float_query_param (trino_connection ):
128+ cur = trino_connection .cursor ()
129+
130+ cur .execute ("SELECT ?" , params = (1.23 ,))
131+ rows = cur .fetchall ()
132+
133+ assert rows [0 ][0 ] == 1.23
134+
135+
136+ def test_decimal_query_param (trino_connection ):
137+ cur = trino_connection .cursor ()
138+
139+ cur .execute ("SELECT ?" , params = (Decimal ('0.142857' ),))
140+ rows = cur .fetchall ()
141+
142+ assert rows [0 ][0 ] == Decimal ('0.142857' )
143+
144+
126145def test_datetime_query_param (trino_connection ):
127146 cur = trino_connection .cursor ()
128147
129148 cur .execute ("SELECT ?" , params = (datetime (2020 , 1 , 1 , 0 , 0 , 0 ),))
130149 rows = cur .fetchall ()
131150
132- assert rows [0 ][0 ] == " 2020-01-01 00:00:00.000"
151+ assert rows [0 ][0 ] == datetime ( 2020 , 1 , 1 , 0 , 0 , 0 )
133152
134153 cur .execute ("SELECT ?" ,
135154 params = (datetime (2020 , 1 , 1 , 0 , 0 , 0 , tzinfo = pytz .utc ),))
136155 rows = cur .fetchall ()
137156
138- assert rows [0 ][0 ] == " 2020-01-01 00:00:00.000 UTC"
157+ assert rows [0 ][0 ] == datetime ( 2020 , 1 , 1 , 0 , 0 , 0 , tzinfo = pytz . utc )
139158 assert cur .description [0 ][1 ] == "timestamp with time zone"
140159
141160
@@ -158,6 +177,33 @@ def test_array_query_param(trino_connection):
158177 assert rows [0 ][0 ] == "array(integer)"
159178
160179
180+ def test_array_timestamp_query_param (trino_connection ):
181+ cur = trino_connection .cursor ()
182+ cur .execute ("SELECT ?" , params = ([datetime (2020 , 1 , 1 , 0 , 0 , 0 ), datetime (2020 , 1 , 2 , 0 , 0 , 0 )],))
183+ rows = cur .fetchall ()
184+
185+ assert rows [0 ][0 ] == [datetime (2020 , 1 , 1 , 0 , 0 , 0 ), datetime (2020 , 1 , 2 , 0 , 0 , 0 )]
186+
187+ cur .execute ("SELECT TYPEOF(?)" , params = ([datetime (2020 , 1 , 1 , 0 , 0 , 0 ), datetime (2020 , 1 , 2 , 0 , 0 , 0 )],))
188+ rows = cur .fetchall ()
189+
190+ assert rows [0 ][0 ] == "array(timestamp(6))"
191+
192+
193+ def test_array_timestamp_with_timezone_query_param (trino_connection ):
194+ cur = trino_connection .cursor ()
195+
196+ cur .execute ("SELECT ?" , params = ([datetime (2020 , 1 , 1 , 0 , 0 , 0 , tzinfo = pytz .utc ), datetime (2020 , 1 , 2 , 0 , 0 , 0 , tzinfo = pytz .utc )],))
197+ rows = cur .fetchall ()
198+
199+ assert rows [0 ][0 ] == [datetime (2020 , 1 , 1 , 0 , 0 , 0 , tzinfo = pytz .utc ), datetime (2020 , 1 , 2 , 0 , 0 , 0 , tzinfo = pytz .utc )]
200+
201+ cur .execute ("SELECT TYPEOF(?)" , params = ([datetime (2020 , 1 , 1 , 0 , 0 , 0 , tzinfo = pytz .utc ), datetime (2020 , 1 , 2 , 0 , 0 , 0 , tzinfo = pytz .utc )],))
202+ rows = cur .fetchall ()
203+
204+ assert rows [0 ][0 ] == "array(timestamp(6) with time zone)"
205+
206+
161207def test_dict_query_param (trino_connection ):
162208 cur = trino_connection .cursor ()
163209
0 commit comments