Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ celerybeat.pid

# Environments
.env
.env.*
.venv
env/
venv/
Expand Down
3 changes: 3 additions & 0 deletions deepnote_toolkit/sql/jinjasql_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@ def render_jinja_sql_template(template, param_style=None):
Args:
template (str): The Jinja SQL template to render.
param_style (str, optional): The parameter style to use. Defaults to "pyformat".
Common styles: "qmark" (?), "format" (%s), "pyformat" (%(name)s)

Returns:
str: The rendered SQL query.
"""

escaped_template = _escape_jinja_template(template)

# Default to pyformat for backwards compatibility
# Note: Some databases like Trino require "qmark" or "format" style
jinja_sql = JinjaSql(
param_style=param_style if param_style is not None else "pyformat"
)
Expand Down
17 changes: 16 additions & 1 deletion deepnote_toolkit/sql/sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,16 @@ class ExecuteSqlError(Exception):
del sql_alchemy_dict["params"]["snowflake_private_key_passphrase"]

param_style = sql_alchemy_dict.get("param_style")

# Auto-detect param_style for databases that don't support pyformat default
if param_style is None:
url_obj = make_url(sql_alchemy_dict["url"])
# Mapping of SQLAlchemy dialect names to their required param_style
dialect_param_styles = {
"trino": "qmark", # Trino requires ? placeholders with list/tuple params
}
param_style = dialect_param_styles.get(url_obj.drivername)

skip_template_render = re.search(
"^snowflake.*host=.*.proxy.cloud.getdbt.com", sql_alchemy_dict["url"]
)
Expand Down Expand Up @@ -425,10 +435,15 @@ def _execute_sql_on_engine(engine, query, bind_params):
connection.connection if needs_raw_connection else connection
)

# pandas.read_sql_query expects params as tuple (not list) for qmark/format style
params_for_pandas = (
tuple(bind_params) if isinstance(bind_params, list) else bind_params
)

return pd.read_sql_query(
query,
con=connection_for_pandas,
params=bind_params,
params=params_for_pandas,
coerce_float=coerce_float,
)
except ResourceClosedError:
Expand Down
17 changes: 16 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ dev = [
"poetry-dynamic-versioning>=1.4.0,<2.0.0",
"twine>=6.1.0,<7.0.0",
"codespell>=2.3.0,<3.0.0",
"pytest-subtests>=0.15.0,<0.16.0"
"pytest-subtests>=0.15.0,<0.16.0",
"python-dotenv>=1.2.1,<2.0.0"
]
license-check = [
# Dependencies needed for license checking that aren't in main production dependencies
Expand Down
227 changes: 227 additions & 0 deletions tests/integration/test_trino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
import json
import os
from contextlib import contextmanager
from pathlib import Path
from unittest import mock
from urllib.parse import quote

import pandas as pd
import pytest
from dotenv import load_dotenv
from trino import dbapi
from trino.auth import BasicAuthentication

from deepnote_toolkit import env as dnenv
from deepnote_toolkit.sql.sql_execution import execute_sql


@contextmanager
def use_trino_sql_connection(connection_json, env_var_name="TEST_TRINO_CONNECTION"):
    """Temporarily expose a connection JSON through a Toolkit env variable.

    Args:
        connection_json (str): Serialized connection description to store.
        env_var_name (str, optional): Name of the variable to set. Defaults to
            "TEST_TRINO_CONNECTION".

    Yields:
        str: ``env_var_name``, so callers can hand it straight to ``execute_sql``.
    """
    dnenv.set_env(env_var_name, connection_json)
    try:
        yield env_var_name
    finally:
        # Always remove the variable, even when the body of the `with` raised.
        dnenv.unset_env(env_var_name)


@pytest.fixture(scope="module")
def trino_credentials():
    """Collect Trino connection settings from environment variables.

    If a ``.env`` file exists three directory levels above this test module,
    it is loaded first. The whole module is skipped when ``TRINO_HOST`` or
    ``TRINO_USER`` is missing or empty.

    Returns:
        dict: ``host``, ``port`` (int), ``user``, ``password``, ``catalog``,
        ``schema`` and ``http_scheme``.
    """
    dotenv_file = Path(__file__).parent.parent.parent / ".env"
    if dotenv_file.exists():
        load_dotenv(dotenv_file)

    settings = {
        "host": os.getenv("TRINO_HOST"),
        "port": os.getenv("TRINO_PORT", "8080"),
        "user": os.getenv("TRINO_USER"),
        "password": os.getenv("TRINO_PASSWORD"),
        "catalog": os.getenv("TRINO_CATALOG", "system"),
        "schema": os.getenv("TRINO_SCHEMA", "runtime"),
        "http_scheme": os.getenv("TRINO_HTTP_SCHEME", "https"),
    }

    if not (settings["host"] and settings["user"]):
        pytest.skip(
            "Trino credentials not found. "
            "Please set TRINO_HOST and TRINO_USER in .env file"
        )

    # Convert only after the skip check, so a missing host/user skips the
    # module before the port value is ever parsed.
    settings["port"] = int(settings["port"])
    return settings


@pytest.fixture(scope="module")
def trino_connection(trino_credentials):
    """Yield a Trino DB-API connection and close it at teardown.

    Args:
        trino_credentials (dict): Settings produced by the ``trino_credentials``
            fixture.

    Yields:
        The connection object returned by ``trino.dbapi.connect``.
    """
    creds = trino_credentials

    # Basic authentication is configured only when a password is present;
    # otherwise no auth object is passed.
    auth = (
        BasicAuthentication(creds["user"], creds["password"])
        if creds["password"]
        else None
    )

    connect_kwargs = {
        key: creds[key]
        for key in ("host", "port", "user", "http_scheme", "catalog", "schema")
    }
    conn = dbapi.connect(auth=auth, **connect_kwargs)

    try:
        yield conn
    finally:
        conn.close()


class TestTrinoConnection:
    """Test Trino database connection."""

    def test_connection_established(self, trino_connection):
        """Test that connection to Trino is established."""
        cursor = trino_connection.cursor()
        # Close the cursor even when the query or an assertion fails, so the
        # module-scoped connection shared with later tests is not left with a
        # dangling cursor.
        try:
            cursor.execute("SELECT 1")
            result = cursor.fetchone()

            assert result is not None
            assert result[0] == 1
        finally:
            cursor.close()

    def test_show_catalogs(self, trino_connection):
        """Test listing available catalogs."""
        cursor = trino_connection.cursor()
        # Same cleanup guarantee as above: always release the cursor.
        try:
            cursor.execute("SHOW CATALOGS")
            catalogs = cursor.fetchall()

            assert len(catalogs) > 0
            assert any("system" in str(catalog) for catalog in catalogs)
        finally:
            cursor.close()


@pytest.fixture
def trino_toolkit_connection(trino_credentials):
    """Create a Trino connection JSON for deepnote toolkit.

    The JSON is published through ``use_trino_sql_connection`` for the
    duration of the test.

    Yields:
        str: Name of the env variable holding the connection JSON.
    """
    creds = trino_credentials

    # Percent-encode every reserved character so credentials cannot break the URL.
    encoded_user = quote(creds["user"], safe="")
    if creds["password"]:
        encoded_password = f":{quote(creds['password'], safe='')}"
    else:
        encoded_password = ""

    connection_url = (
        f"trino://{encoded_user}{encoded_password}"
        f"@{creds['host']}:{creds['port']}"
        f"/{creds['catalog']}/{creds['schema']}"
    )

    # Trino uses the `qmark` paramstyle (`?` placeholders with list/tuple
    # params) rather than the default pyformat, so it is set explicitly here.
    payload = {
        "url": connection_url,
        "params": {},
        "param_style": "qmark",
    }

    with use_trino_sql_connection(json.dumps(payload)) as env_var_name:
        yield env_var_name


class TestTrinoWithDeepnoteToolkit:
    """Test Trino connection using Toolkit's SQL execution."""

    def test_execute_sql_simple_query(self, trino_toolkit_connection):
        """A constant SELECT comes back as a one-row DataFrame."""
        frame = execute_sql(
            template="SELECT 1 as test_value",
            sql_alchemy_json_env_var=trino_toolkit_connection,
        )

        assert isinstance(frame, pd.DataFrame)
        assert len(frame) == 1
        assert "test_value" in frame.columns
        assert frame["test_value"].iloc[0] == 1

    def test_execute_sql_with_jinja_template(self, trino_toolkit_connection):
        """Jinja variables are bound as query parameters with explicit qmark style."""
        expected_message = "test string"
        expected_number = 123
        notebook_variables = {
            "test_string_var": expected_message,
            "test_number_var": expected_number,
        }

        # Stand-in for the Toolkit's variable lookup: resolve template
        # variables from the local mapping above.
        with mock.patch(
            "deepnote_toolkit.sql.jinjasql_utils._get_variable_value",
            side_effect=lambda variable_name: notebook_variables[variable_name],
        ):
            frame = execute_sql(
                template="SELECT {{test_string_var}} as message, {{test_number_var}} as number",
                sql_alchemy_json_env_var=trino_toolkit_connection,
            )

        assert isinstance(frame, pd.DataFrame)
        assert len(frame) == 1
        assert "message" in frame.columns
        assert "number" in frame.columns
        assert frame["message"].iloc[0] == expected_message
        assert frame["number"].iloc[0] == expected_number

    def test_execute_sql_with_autodetection(self, trino_credentials):
        """
        Test execute_sql with auto-detection of param_style
        (regression reported in BLU-5135)

        This simulates the real-world scenario where the backend provides a connection
        JSON without explicit param_style, and Toolkit must auto-detect it.
        """
        creds = trino_credentials

        encoded_user = quote(creds["user"], safe="")
        if creds["password"]:
            encoded_password = f":{quote(creds['password'], safe='')}"
        else:
            encoded_password = ""

        connection_url = (
            f"trino://{encoded_user}{encoded_password}"
            f"@{creds['host']}:{creds['port']}"
            f"/{creds['catalog']}/{creds['schema']}"
        )

        # Deliberately NO param_style key - it should be auto-detected as
        # `qmark` for Trino.
        payload = {
            "url": connection_url,
            "params": {},
        }

        expected_value = "test value"

        with use_trino_sql_connection(
            json.dumps(payload), "TEST_TRINO_AUTODETECT"
        ) as env_var_name:
            with mock.patch(
                "deepnote_toolkit.sql.jinjasql_utils._get_variable_value",
                return_value=expected_value,
            ):
                frame = execute_sql(
                    template="SELECT {{test_var}} as detected",
                    sql_alchemy_json_env_var=env_var_name,
                )

        assert isinstance(frame, pd.DataFrame)
        assert len(frame) == 1
        assert "detected" in frame.columns
        assert frame["detected"].iloc[0] == expected_value
Loading
Loading