Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions providers/oracle/src/airflow/providers/oracle/hooks/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,16 @@ def get_conn(self) -> oracledb.Connection:
that you are connecting to (CONNECT_DATA part of TNS)
:param sid: Oracle System ID that identifies a particular
database on a system
:param wallet_location: Specify the directory where the wallet can be found.
:param wallet_password: the password to use to decrypt the wallet, if it is encrypted.
For Oracle Autonomous Database this is the password created when downloading the wallet.
:param ssl_server_cert_dn: Specify the distinguished name (DN) which should be matched
with the server. This value is ignored if the ``ssl_server_dn_match`` parameter is not
set to the value True.
:param ssl_server_dn_match: Specify whether the server certificate distinguished name
(DN) should be matched in addition to the regular certificate verification that is performed.
:param cclass: the connection class to use for Database Resident Connection Pooling (DRCP).
:param pool_name: the name of the DRCP pool when using multi-pool DRCP with Oracle Database 23.4, or higher.

You can set these parameters in the extra fields of your connection
as in
Expand Down Expand Up @@ -221,6 +231,8 @@ def get_conn(self) -> oracledb.Connection:
if "events" in conn.extra_dejson:
conn_config["events"] = conn.extra_dejson.get("events")

# TODO: Replace mapping with oracledb.AuthMode enum once python-oracledb>=2.3
# mode = getattr(oracledb.AuthMode, conn.extra_dejson.get("mode", "").upper(), None)
mode = conn.extra_dejson.get("mode", "").lower()
if mode == "sysdba":
conn_config["mode"] = oracledb.AUTH_MODE_SYSDBA
Expand All @@ -237,6 +249,8 @@ def get_conn(self) -> oracledb.Connection:
elif mode == "sysrac":
conn_config["mode"] = oracledb.AUTH_MODE_SYSRAC

# TODO: Replace mapping with oracledb.Purity enum once python-oracledb>=2.3
# purity = getattr(oracledb.Purity, conn.extra_dejson.get("purity", "").upper(), None)
purity = conn.extra_dejson.get("purity", "").lower()
if purity == "new":
conn_config["purity"] = oracledb.PURITY_NEW
Expand All @@ -249,6 +263,18 @@ def get_conn(self) -> oracledb.Connection:
if expire_time:
conn_config["expire_time"] = expire_time

for name in [
"wallet_location",
"wallet_password",
"ssl_server_cert_dn",
"ssl_server_dn_match",
"cclass",
"pool_name",
]:
value = conn.extra_dejson.get(name)
if value is not None:
conn_config[name] = value

oracle_conn = oracledb.connect(**conn_config)
if mod is not None:
oracle_conn.module = mod
Expand Down
21 changes: 21 additions & 0 deletions providers/oracle/tests/unit/oracle/hooks/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,27 @@ def test_get_uri(self, mock_connect, connection_params, expected_uri):
uri = self.db_hook.get_uri()
assert uri == expected_uri

@mock.patch("airflow.providers.oracle.hooks.oracle.oracledb.connect")
def test_get_conn_with_various_params(self, mock_connect):
"""Verify wallet/SSL, connection class, and pool parameters
are passed to oracledb.connect."""
params = {
"wallet_location": "/tmp/wallet",
"wallet_password": "secret",
"ssl_server_cert_dn": "CN=dbserver,OU=DB,O=Oracle,L=BLR,C=IN",
"ssl_server_dn_match": True,
"cclass": "MY_APP_CLASS",
"pool_name": "POOL_1",
}
self.connection.extra = json.dumps(params)
self.db_hook.get_conn()

assert mock_connect.call_count == 1
_, kwargs = mock_connect.call_args

for key, value in params.items():
assert kwargs[key] == value


class TestOracleHook:
def setup_method(self):
Expand Down