Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions providers/jdbc/src/airflow/providers/jdbc/hooks/jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from contextlib import contextmanager
from threading import RLock
from typing import TYPE_CHECKING, Any
from urllib.parse import quote_plus, urlencode

import jaydebeapi
import jpype
Expand Down Expand Up @@ -220,3 +221,38 @@ def get_autocommit(self, conn: jaydebeapi.Connection) -> bool:
with suppress_and_warn(jaydebeapi.Error, jpype.JException):
return conn.jconn.getAutoCommit()
return False

def get_uri(self) -> str:
"""Get the connection URI for the JDBC connection."""
conn = self.connection
extra = conn.extra_dejson

scheme = extra.get("sqlalchemy_scheme")
if not scheme:
return conn.host

driver = extra.get("sqlalchemy_driver")
uri_prefix = f"{scheme}+{driver}" if driver else scheme

auth_part = ""
if conn.login:
auth_part = quote_plus(conn.login)
if conn.password:
auth_part = f"{auth_part}:{quote_plus(conn.password)}"
auth_part = f"{auth_part}@"

host_part = conn.host or "localhost"
if conn.port:
host_part = f"{host_part}:{conn.port}"

schema_part = f"/{quote_plus(conn.schema)}" if conn.schema else ""

uri = f"{uri_prefix}://{auth_part}{host_part}{schema_part}"

sqlalchemy_query = extra.get("sqlalchemy_query", {})
if isinstance(sqlalchemy_query, dict):
query_string = urlencode({k: str(v) for k, v in sqlalchemy_query.items() if v is not None})
if query_string:
uri = f"{uri}?{query_string}"

return uri
67 changes: 67 additions & 0 deletions providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,70 @@ def call_get_conn():
future.result() # This will raise OSError if get_conn isn't threadsafe

assert mock_connect.call_count == 10

@pytest.mark.parametrize(
"params,expected_uri",
[
# JDBC URL fallback cases
pytest.param(
{"host": "jdbc:mysql://localhost:3306/test"},
"jdbc:mysql://localhost:3306/test",
id="jdbc-mysql",
),
pytest.param(
{"host": "jdbc:postgresql://localhost:5432/test?user=user&password=pass%40word"},
"jdbc:postgresql://localhost:5432/test?user=user&password=pass%40word",
id="jdbc-postgresql",
),
pytest.param(
{"host": "jdbc:oracle:thin:@localhost:1521:xe"},
"jdbc:oracle:thin:@localhost:1521:xe",
id="jdbc-oracle",
),
pytest.param(
{"host": "jdbc:sqlserver://localhost:1433;databaseName=test;trustServerCertificate=true"},
"jdbc:sqlserver://localhost:1433;databaseName=test;trustServerCertificate=true",
id="jdbc-sqlserver",
),
# SQLAlchemy URI cases
pytest.param(
{
"conn_params": {
"extra": json.dumps(
{"sqlalchemy_scheme": "mssql", "sqlalchemy_query": {"servicename": "test"}}
)
}
},
"mssql://login:password@host:1234/schema?servicename=test",
id="sqlalchemy-scheme-with-query",
),
pytest.param(
{
"conn_params": {
"extra": json.dumps(
{"sqlalchemy_scheme": "postgresql", "sqlalchemy_driver": "psycopg2"}
)
}
},
"postgresql+psycopg2://login:password@host:1234/schema",
id="sqlalchemy-scheme-with-driver",
),
pytest.param(
{
"login": "user@domain",
"password": "pass/word",
"schema": "my/db",
"conn_params": {"extra": json.dumps({"sqlalchemy_scheme": "mysql"})},
},
"mysql://user%40domain:pass%2Fword@host:1234/my%2Fdb",
id="sqlalchemy-with-encoding",
),
],
)
def test_get_uri(self, params, expected_uri):
"""Test get_uri with different configurations including JDBC URLs and SQLAlchemy URIs."""
valid_keys = {"host", "login", "password", "schema", "conn_params"}
hook_params = {key: params[key] for key in valid_keys & params.keys()}

jdbc_hook = get_hook(**hook_params)
assert jdbc_hook.get_uri() == expected_uri