Skip to content

Commit b3c7e06

Browse files
committed
Add Cursor.describe to retrieve the schema of a query
1 parent b020474 commit b3c7e06

File tree

2 files changed

+103
-1
lines changed

2 files changed

+103
-1
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import trino
2121
from tests.integration.conftest import trino_version
2222
from trino import constants
23+
from trino.dbapi import DescribeOutput
2324
from trino.exceptions import NotSupportedError, TrinoQueryError, TrinoUserError
2425
from trino.transaction import IsolationLevel
2526

@@ -1107,3 +1108,69 @@ def test_prepared_statements(run_trino):
11071108
cur.execute('DEALLOCATE PREPARE test_prepared_statements')
11081109
cur.fetchall()
11091110
assert cur._request._client_session.prepared_statements == {}
1111+
1112+
1113+
def test_describe(run_trino):
1114+
_, host, port = run_trino
1115+
1116+
trino_connection = trino.dbapi.Connection(
1117+
host=host, port=port, user="test", catalog="tpch",
1118+
)
1119+
cur = trino_connection.cursor()
1120+
1121+
result = cur.describe("SELECT 1, DECIMAL '1.0' as a")
1122+
1123+
assert result == [
1124+
DescribeOutput(name='_col0', catalog='', schema='', table='', type='integer', type_size=4, aliased=False),
1125+
DescribeOutput(name='a', catalog='', schema='', table='', type='decimal(2,1)', type_size=8, aliased=True)
1126+
]
1127+
1128+
1129+
def test_describe_table_query(run_trino):
1130+
_, host, port = run_trino
1131+
1132+
trino_connection = trino.dbapi.Connection(
1133+
host=host, port=port, user="test", catalog="tpch",
1134+
)
1135+
cur = trino_connection.cursor()
1136+
1137+
result = cur.describe("SELECT * from tpch.tiny.nation")
1138+
1139+
assert result == [
1140+
DescribeOutput(
1141+
name='nationkey',
1142+
catalog='tpch',
1143+
schema='tiny',
1144+
table='nation',
1145+
type='bigint',
1146+
type_size=8,
1147+
aliased=False,
1148+
),
1149+
DescribeOutput(
1150+
name='name',
1151+
catalog='tpch',
1152+
schema='tiny',
1153+
table='nation',
1154+
type='varchar(25)',
1155+
type_size=0,
1156+
aliased=False,
1157+
),
1158+
DescribeOutput(
1159+
name='regionkey',
1160+
catalog='tpch',
1161+
schema='tiny',
1162+
table='nation',
1163+
type='bigint',
1164+
type_size=8,
1165+
aliased=False,
1166+
),
1167+
DescribeOutput(
1168+
name='comment',
1169+
catalog='tpch',
1170+
schema='tiny',
1171+
table='nation',
1172+
type='varchar(152)',
1173+
type_size=0,
1174+
aliased=False,
1175+
)
1176+
]

trino/dbapi.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import math
2222
import uuid
2323
from decimal import Decimal
24-
from typing import Any, Dict, List, Optional # NOQA for mypy types
24+
from typing import Any, Dict, List, NamedTuple, Optional # NOQA for mypy types
2525

2626
import trino.client
2727
import trino.exceptions
@@ -220,6 +220,20 @@ def cursor(self, experimental_python_types: bool = None):
220220
)
221221

222222

223+
class DescribeOutput(NamedTuple):
224+
name: str
225+
catalog: str
226+
schema: str
227+
table: str
228+
type: str
229+
type_size: int
230+
aliased: bool
231+
232+
@classmethod
233+
def from_row(cls, row: List[Any]):
234+
return cls(*row)
235+
236+
223237
class Cursor(object):
224238
"""Database cursor.
225239
@@ -517,6 +531,27 @@ def fetchmany(self, size=None) -> List[List[Any]]:
517531

518532
return result
519533

534+
def describe(self, sql: str) -> List[DescribeOutput]:
535+
"""
536+
Returns the schema of a given SQL statement
537+
538+
:param sql: SQL statement
539+
"""
540+
statement_name = self._generate_unique_statement_name()
541+
self._prepare_statement(sql, statement_name)
542+
try:
543+
sql = f"DESCRIBE OUTPUT {statement_name}"
544+
self._query = trino.client.TrinoQuery(
545+
self._request,
546+
sql=sql,
547+
experimental_python_types=self._experimental_pyton_types,
548+
)
549+
result = self._query.execute()
550+
finally:
551+
self._deallocate_prepared_statement(statement_name)
552+
553+
return list(map(lambda x: DescribeOutput.from_row(x), result))
554+
520555
def genall(self):
521556
return self._query.result
522557

0 commit comments

Comments
 (0)