Skip to content

Commit 5eddb88

Browse files
authored
fix: Trino SQL parameterization and add integration tests (#22)
* fix: Trino SQL parameterization and add integration tests * chore: Improve test coverage * fix: Lint * chore: Improve test coverage
1 parent 5bf9888 commit 5eddb88

File tree

7 files changed

+494
-2
lines changed

7 files changed

+494
-2
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ celerybeat.pid
133133

134134
# Environments
135135
.env
136+
.env.*
136137
.venv
137138
env/
138139
venv/

deepnote_toolkit/sql/jinjasql_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,16 @@ def render_jinja_sql_template(template, param_style=None):
1414
Args:
1515
template (str): The Jinja SQL template to render.
1616
param_style (str, optional): The parameter style to use. Defaults to "pyformat".
17+
Common styles: "qmark" (?), "format" (%s), "pyformat" (%(name)s)
1718
1819
Returns:
1920
str: The rendered SQL query.
2021
"""
2122

2223
escaped_template = _escape_jinja_template(template)
2324

25+
# Default to pyformat for backwards compatibility
26+
# Note: Some databases, such as Trino, require the "qmark" style (? placeholders) instead
2427
jinja_sql = JinjaSql(
2528
param_style=param_style if param_style is not None else "pyformat"
2629
)

deepnote_toolkit/sql/sql_execution.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,16 @@ class ExecuteSqlError(Exception):
150150
del sql_alchemy_dict["params"]["snowflake_private_key_passphrase"]
151151

152152
param_style = sql_alchemy_dict.get("param_style")
153+
154+
# Auto-detect param_style for databases that don't support pyformat default
155+
if param_style is None:
156+
url_obj = make_url(sql_alchemy_dict["url"])
157+
# Mapping of SQLAlchemy dialect names to their required param_style
158+
dialect_param_styles = {
159+
"trino": "qmark", # Trino requires ? placeholders with list/tuple params
160+
}
161+
param_style = dialect_param_styles.get(url_obj.drivername)
162+
153163
skip_template_render = re.search(
154164
"^snowflake.*host=.*.proxy.cloud.getdbt.com", sql_alchemy_dict["url"]
155165
)
@@ -425,10 +435,15 @@ def _execute_sql_on_engine(engine, query, bind_params):
425435
connection.connection if needs_raw_connection else connection
426436
)
427437

438+
# pandas.read_sql_query expects params as tuple (not list) for qmark/format style
439+
params_for_pandas = (
440+
tuple(bind_params) if isinstance(bind_params, list) else bind_params
441+
)
442+
428443
return pd.read_sql_query(
429444
query,
430445
con=connection_for_pandas,
431-
params=bind_params,
446+
params=params_for_pandas,
432447
coerce_float=coerce_float,
433448
)
434449
except ResourceClosedError:

poetry.lock

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,8 @@ dev = [
200200
"poetry-dynamic-versioning>=1.4.0,<2.0.0",
201201
"twine>=6.1.0,<7.0.0",
202202
"codespell>=2.3.0,<3.0.0",
203-
"pytest-subtests>=0.15.0,<0.16.0"
203+
"pytest-subtests>=0.15.0,<0.16.0",
204+
"python-dotenv>=1.2.1,<2.0.0"
204205
]
205206
license-check = [
206207
# Dependencies needed for license checking that aren't in main production dependencies

tests/integration/test_trino.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
import json
2+
import os
3+
from contextlib import contextmanager
4+
from pathlib import Path
5+
from unittest import mock
6+
from urllib.parse import quote
7+
8+
import pandas as pd
9+
import pytest
10+
from dotenv import load_dotenv
11+
from trino import dbapi
12+
from trino.auth import BasicAuthentication
13+
14+
from deepnote_toolkit import env as dnenv
15+
from deepnote_toolkit.sql.sql_execution import execute_sql
16+
17+
18+
@contextmanager
def use_trino_sql_connection(connection_json, env_var_name="TEST_TRINO_CONNECTION"):
    """Temporarily publish a connection JSON under a toolkit env variable.

    Args:
        connection_json (str): Serialized connection definition to store.
        env_var_name (str): Name of the variable that holds the JSON.

    Yields:
        str: ``env_var_name``, so callers can hand it to the code under test.
    """
    dnenv.set_env(env_var_name, connection_json)
    try:
        yield env_var_name
    finally:
        # Remove the variable even when the body of the ``with`` block raised.
        dnenv.unset_env(env_var_name)
25+
26+
27+
@pytest.fixture(scope="module")
def trino_credentials():
    """Collect Trino connection settings from environment variables.

    When a ``.env`` file exists three path levels above this file, it is loaded
    first. Tests depending on this fixture are skipped if ``TRINO_HOST`` or
    ``TRINO_USER`` is not set.

    Returns:
        dict: ``host``, ``port`` (int), ``user``, ``password``, ``catalog``,
        ``schema`` and ``http_scheme``.
    """
    dotenv_file = Path(__file__).parent.parent.parent / ".env"
    if dotenv_file.exists():
        load_dotenv(dotenv_file)

    settings = {
        "host": os.getenv("TRINO_HOST"),
        "port": os.getenv("TRINO_PORT", "8080"),
        "user": os.getenv("TRINO_USER"),
        "password": os.getenv("TRINO_PASSWORD"),
        "catalog": os.getenv("TRINO_CATALOG", "system"),
        "schema": os.getenv("TRINO_SCHEMA", "runtime"),
        "http_scheme": os.getenv("TRINO_HTTP_SCHEME", "https"),
    }

    if not (settings["host"] and settings["user"]):
        pytest.skip(
            "Trino credentials not found. "
            "Please set TRINO_HOST and TRINO_USER in .env file"
        )

    # Convert the port only after the skip check, so a missing host/user
    # results in a skip rather than a conversion error.
    settings["port"] = int(settings["port"])
    return settings
57+
58+
59+
@pytest.fixture(scope="module")
def trino_connection(trino_credentials):
    """Yield a Trino DB-API connection shared by the tests in this module.

    Basic authentication is configured only when a password is present in the
    credentials. The connection is closed when the fixture is torn down.
    """
    creds = trino_credentials
    secret = creds["password"]
    authentication = BasicAuthentication(creds["user"], secret) if secret else None

    connection = dbapi.connect(
        host=creds["host"],
        port=creds["port"],
        user=creds["user"],
        auth=authentication,
        http_scheme=creds["http_scheme"],
        catalog=creds["catalog"],
        schema=creds["schema"],
    )

    try:
        yield connection
    finally:
        connection.close()
82+
83+
84+
class TestTrinoConnection:
    """Test Trino database connection."""

    def test_connection_established(self, trino_connection):
        """Test that connection to Trino is established."""
        cursor = trino_connection.cursor()
        try:
            cursor.execute("SELECT 1")
            result = cursor.fetchone()

            assert result is not None
            assert result[0] == 1
        finally:
            # The connection is shared across the module, so release the cursor
            # even when the query or an assertion fails.
            cursor.close()

    def test_show_catalogs(self, trino_connection):
        """Test listing available catalogs."""
        cursor = trino_connection.cursor()
        try:
            cursor.execute("SHOW CATALOGS")
            catalogs = cursor.fetchall()

            assert len(catalogs) > 0
            assert any("system" in str(catalog) for catalog in catalogs)
        finally:
            # Same reasoning as above: never leave a cursor open on failure.
            cursor.close()
108+
109+
110+
@pytest.fixture
def trino_toolkit_connection(trino_credentials):
    """Publish a Trino connection JSON for deepnote toolkit and yield its env var name.

    The user name and password are percent-encoded before being placed in the
    ``trino://`` URL; the password segment is left out when no password is set.
    """
    creds = trino_credentials

    userinfo = quote(creds["user"], safe="")
    if creds["password"]:
        userinfo += f":{quote(creds['password'], safe='')}"

    location = f"{creds['host']}:{creds['port']}"
    database = f"{creds['catalog']}/{creds['schema']}"
    connection_url = f"trino://{userinfo}@{location}/{database}"

    # Trino uses `qmark` paramstyle (`?` placeholders with list/tuple params),
    # not pyformat, which is the default, so it is set explicitly here.
    payload = {
        "url": connection_url,
        "params": {},
        "param_style": "qmark",
    }

    with use_trino_sql_connection(json.dumps(payload)) as env_var_name:
        yield env_var_name
136+
137+
138+
class TestTrinoWithDeepnoteToolkit:
    """Test Trino connection using Toolkit's SQL execution."""

    def test_execute_sql_simple_query(self, trino_toolkit_connection):
        """A constant query without template variables returns a one-row frame."""
        frame = execute_sql(
            template="SELECT 1 as test_value",
            sql_alchemy_json_env_var=trino_toolkit_connection,
        )

        assert isinstance(frame, pd.DataFrame)
        assert len(frame) == 1
        assert "test_value" in frame.columns
        assert frame["test_value"].iloc[0] == 1

    def test_execute_sql_with_jinja_template(self, trino_toolkit_connection):
        """Jinja variables are bound as query parameters with explicit `qmark` style."""
        test_string = "test string"
        test_number = 123
        known_variables = {
            "test_string_var": test_string,
            "test_number_var": test_number,
        }

        def lookup_variable(variable_name):
            # Stand-in for the toolkit's variable lookup; unknown names raise KeyError.
            return known_variables[variable_name]

        variable_patch = mock.patch(
            "deepnote_toolkit.sql.jinjasql_utils._get_variable_value",
            side_effect=lookup_variable,
        )
        with variable_patch:
            frame = execute_sql(
                template="SELECT {{test_string_var}} as message, {{test_number_var}} as number",
                sql_alchemy_json_env_var=trino_toolkit_connection,
            )

        assert isinstance(frame, pd.DataFrame)
        assert len(frame) == 1
        assert "message" in frame.columns
        assert "number" in frame.columns
        assert frame["message"].iloc[0] == test_string
        assert frame["number"].iloc[0] == test_number

    def test_execute_sql_with_autodetection(self, trino_credentials):
        """
        Test execute_sql with auto-detection of param_style
        (regression reported in BLU-5135)

        The connection JSON built here deliberately carries no ``param_style``
        key, so the query only succeeds if Toolkit picks a style that works
        for Trino on its own.
        """
        creds = trino_credentials

        userinfo = quote(creds["user"], safe="")
        if creds["password"]:
            userinfo += f":{quote(creds['password'], safe='')}"

        location = f"{creds['host']}:{creds['port']}"
        database = f"{creds['catalog']}/{creds['schema']}"
        connection_url = f"trino://{userinfo}@{location}/{database}"

        # NO param_style - should auto-detect to `qmark` for Trino
        connection_json = json.dumps({"url": connection_url, "params": {}})

        test_value = "test value"

        # Enter the env-var context first and the patch second, then unwind in
        # reverse order.
        with use_trino_sql_connection(
            connection_json, "TEST_TRINO_AUTODETECT"
        ) as env_var_name:
            with mock.patch(
                "deepnote_toolkit.sql.jinjasql_utils._get_variable_value",
                return_value=test_value,
            ):
                frame = execute_sql(
                    template="SELECT {{test_var}} as detected",
                    sql_alchemy_json_env_var=env_var_name,
                )

        assert isinstance(frame, pd.DataFrame)
        assert len(frame) == 1
        assert "detected" in frame.columns
        assert frame["detected"].iloc[0] == test_value

0 commit comments

Comments
 (0)