Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion providers/exasol/src/airflow/providers/exasol/hooks/exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pyexasol
from deprecated import deprecated
from pyexasol import ExaConnection, ExaStatement
from sqlalchemy.engine import URL

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.common.sql.hooks.handlers import return_single_query_results
Expand Down Expand Up @@ -53,10 +54,12 @@ class ExasolHook(DbApiHook):
conn_type = "exasol"
hook_name = "Exasol"
supports_autocommit = True
DEFAULT_SQLALCHEMY_SCHEME = "exa+websocket" # sqlalchemy-exasol dialect

def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args, sqlalchemy_scheme: str | None = None, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.schema = kwargs.pop("schema", None)
self._sqlalchemy_scheme = sqlalchemy_scheme

def get_conn(self) -> ExaConnection:
conn = self.get_connection(self.get_conn_id())
Expand All @@ -74,6 +77,46 @@ def get_conn(self) -> ExaConnection:
conn = pyexasol.connect(**conn_args)
return conn

@property
def sqlalchemy_scheme(self) -> str:
"""Sqlalchemy scheme either from constructor, connection extras or default."""
extra_scheme = self.connection is not None and self.connection_extra_lower.get("sqlalchemy_scheme")
sqlalchemy_scheme = self._sqlalchemy_scheme or extra_scheme or self.DEFAULT_SQLALCHEMY_SCHEME
if sqlalchemy_scheme not in ["exa+websocket", "exa+pyodbc", "exa+turbodbc"]:
raise ValueError(
f"sqlalchemy_scheme in connection extra should be one of 'exa+websocket', 'exa+pyodbc' or 'exa+turbodbc', "
f"but got '{sqlalchemy_scheme}'. See https://github.com/exasol/sqlalchemy-exasol?tab=readme-ov-file#using-sqlalchemy-with-exasol-db for more details."
)
return sqlalchemy_scheme

@property
def sqlalchemy_url(self) -> URL:
"""
Return a Sqlalchemy.engine.URL object from the connection.

:return: the extracted sqlalchemy.engine.URL object.
"""
connection = self.connection
query = connection.extra_dejson
query = {k: v for k, v in query.items() if k.lower() != "sqlalchemy_scheme"}
return URL.create(
drivername=self.sqlalchemy_scheme,
username=connection.login,
password=connection.password,
host=connection.host,
port=connection.port,
database=self.schema or connection.schema,
query=query,
)

def get_uri(self) -> str:
"""
Extract the URI from the connection.

:return: the extracted uri.
"""
return self.sqlalchemy_url.render_as_string(hide_password=False)

def _get_pandas_df(
self, sql, parameters: Iterable | Mapping[str, Any] | None = None, **kwargs
) -> pd.DataFrame:
Expand Down
85 changes: 85 additions & 0 deletions providers/exasol/tests/unit/exasol/hooks/test_exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,91 @@ def test_get_conn_extra_args(self, mock_pyexasol):
assert kwargs["encryption"] is True


class TestExasolHookSqlalchemy:
def get_connection(self, extra: dict | None = None) -> models.Connection:
return models.Connection(
login="login",
password="password",
host="host",
port=1234,
schema="schema",
extra=extra,
)

@pytest.mark.parametrize(
"init_scheme, extra_scheme, expected_result, expect_error",
[
(None, None, "exa+websocket", False),
("exa+pyodbc", None, "exa+pyodbc", False),
(None, "exa+turbodbc", "exa+turbodbc", False),
("exa+invalid", None, None, True),
(None, "exa+invalid", None, True),
],
ids=[
"default",
"from_init_arg",
"from_extra",
"invalid_from_init_arg",
"invalid_from_extra",
],
)
def test_sqlalchemy_scheme_property(self, init_scheme, extra_scheme, expected_result, expect_error):
hook = ExasolHook(sqlalchemy_scheme=init_scheme) if init_scheme else ExasolHook()
connection = self.get_connection(extra={"sqlalchemy_scheme": extra_scheme} if extra_scheme else None)
hook.get_connection = mock.Mock(return_value=connection)

if not expect_error:
assert hook.sqlalchemy_scheme == expected_result
else:
with pytest.raises(ValueError):
_ = hook.sqlalchemy_scheme

@pytest.mark.parametrize(
"hook_scheme, extra, expected_url",
[
(None, {}, "exa+websocket://login:password@host:1234/schema"),
(
None,
{"CONNECTIONLCALL": "en_US.UTF-8", "driver": "EXAODBC"},
"exa+websocket://login:password@host:1234/schema?CONNECTIONLCALL=en_US.UTF-8&driver=EXAODBC",
),
(
None,
{"sqlalchemy_scheme": "exa+turbodbc", "CONNECTIONLCALL": "en_US.UTF-8", "driver": "EXAODBC"},
"exa+turbodbc://login:password@host:1234/schema?CONNECTIONLCALL=en_US.UTF-8&driver=EXAODBC",
),
(
"exa+pyodbc",
{
"sqlalchemy_scheme": "exa+turbodbc", # should be overridden
"CONNECTIONLCALL": "en_US.UTF-8",
"driver": "EXAODBC",
},
"exa+pyodbc://login:password@host:1234/schema?CONNECTIONLCALL=en_US.UTF-8&driver=EXAODBC",
),
],
ids=[
"default",
"default_with_extra",
"scheme_from_extra_turbodbc",
"scheme_from_hook",
],
)
def test_sqlalchemy_url_property(self, hook_scheme, extra, expected_url):
hook = ExasolHook(sqlalchemy_scheme=hook_scheme) if hook_scheme else ExasolHook()
hook.get_connection = mock.Mock(return_value=self.get_connection(extra=extra))
assert hook.sqlalchemy_url.render_as_string(hide_password=False) == expected_url

def test_get_uri(self):
hook = ExasolHook()
connection = self.get_connection(extra={"CONNECTIONLCALL": "en_US.UTF-8", "driver": "EXAODBC"})
hook.get_connection = mock.Mock(return_value=connection)
assert (
hook.get_uri()
== "exa+websocket://login:password@host:1234/schema?CONNECTIONLCALL=en_US.UTF-8&driver=EXAODBC"
)


class TestExasolHook:
def setup_method(self):
self.cur = mock.MagicMock(rowcount=lambda: 0)
Expand Down