Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions dev/breeze/tests/test_selective_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1943,7 +1943,7 @@ def test_expected_output_push(
{
"selected-providers-list-as-string": "amazon common.compat common.io common.sql "
"databricks dbt.cloud ftp google microsoft.mssql mysql "
"openlineage postgres sftp snowflake trino",
"openlineage oracle postgres sftp snowflake trino",
"all-python-versions": f"['{DEFAULT_PYTHON_MAJOR_MINOR_VERSION}']",
"all-python-versions-list-as-string": DEFAULT_PYTHON_MAJOR_MINOR_VERSION,
"ci-image-build": "true",
Expand All @@ -1954,7 +1954,7 @@ def test_expected_output_push(
"docs-build": "true",
"docs-list-as-string": "apache-airflow task-sdk amazon common.compat common.io common.sql "
"databricks dbt.cloud ftp google microsoft.mssql mysql "
"openlineage postgres sftp snowflake trino",
"openlineage oracle postgres sftp snowflake trino",
"skip-prek-hooks": ALL_SKIPPED_COMMITS_ON_NO_CI_IMAGE,
"run-kubernetes-tests": "false",
"upgrade-to-newer-dependencies": "false",
Expand All @@ -1964,7 +1964,7 @@ def test_expected_output_push(
{
"description": "amazon...google",
"test_types": "Providers[amazon] Providers[common.compat,common.io,common.sql,"
"databricks,dbt.cloud,ftp,microsoft.mssql,mysql,openlineage,"
"databricks,dbt.cloud,ftp,microsoft.mssql,mysql,openlineage,oracle,"
"postgres,sftp,snowflake,trino] Providers[google]",
}
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,17 @@
import calendar
from datetime import date, datetime, timedelta
from decimal import Decimal
from functools import cached_property
from typing import TYPE_CHECKING

import oracledb

from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
from airflow.providers.oracle.hooks.oracle import OracleHook

if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage


class OracleToGCSOperator(BaseSQLToGCSOperator):
"""
Expand Down Expand Up @@ -62,10 +67,13 @@ def __init__(self, *, oracle_conn_id="oracle_default", ensure_utc=False, **kwarg
self.ensure_utc = ensure_utc
self.oracle_conn_id = oracle_conn_id

@cached_property
def db_hook(self) -> OracleHook:
    """OracleHook built from ``oracle_conn_id``; created lazily and cached on the instance."""
    return OracleHook(oracle_conn_id=self.oracle_conn_id)

def query(self):
"""Query Oracle and returns a cursor to the results."""
oracle = OracleHook(oracle_conn_id=self.oracle_conn_id)
conn = oracle.get_conn()
conn = self.db_hook.get_conn()
cursor = conn.cursor()
if self.ensure_utc:
# Ensure TIMESTAMP results are in UTC
Expand Down Expand Up @@ -121,3 +129,20 @@ def convert_type(self, value, schema_type, **kwargs):
else:
value = base64.standard_b64encode(value).decode("ascii")
return value

def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
    """Collect OpenLineage facets before execution.

    Inputs are derived by parsing ``self.sql`` against the Oracle hook; outputs
    are always the GCS destination datasets. When SQL parsing yields nothing,
    a minimal lineage with the raw query as a job facet is returned.
    """
    # Imports are kept local so OpenLineage remains an optional dependency.
    from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
    from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
    from airflow.providers.openlineage.extractors import OperatorLineage

    hook = self.db_hook
    parsed = get_openlineage_facets_with_sql(
        hook=hook,
        sql=self.sql,
        conn_id=self.oracle_conn_id,
        # Oracle identifies the database by service name, falling back to SID.
        database=hook.service_name or hook.sid,
    )
    gcs_datasets = self._get_openlineage_output_datasets()
    if not parsed:
        # Parsing failed: still report where data lands plus the raw SQL.
        return OperatorLineage(outputs=gcs_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
    # Replace parser-derived outputs with the actual GCS destination.
    parsed.outputs = gcs_datasets
    return parsed
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,17 @@
from __future__ import annotations

from unittest import mock
from unittest.mock import MagicMock

import oracledb
import pytest

from airflow.models import Connection
from airflow.providers.common.compat.openlineage.facet import (
OutputDataset,
SchemaDatasetFacetFields,
)
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.google.cloud.transfers.oracle_to_gcs import OracleToGCSOperator

TASK_ID = "test-oracle-to-gcs"
Expand Down Expand Up @@ -141,3 +149,71 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):

# once for the file and once for the schema
assert gcs_hook_mock.upload.call_count == 2

@pytest.mark.parametrize(
    "input_service_name, input_sid, connection_port, default_port, expected_port",
    [
        # (service_name, sid, port on Connection, default port passed to authority, port expected in namespace)
        ("ServiceName", None, None, 1521, 1521),
        (None, "SID", None, 1521, 1521),
        (None, "SID", 1234, None, 1234),
        (None, "SID", 1234, 1521, 1234),
    ],
)
def test_execute_openlineage_events(
    self, input_service_name, input_sid, connection_port, default_port, expected_port
):
    """Verify get_openlineage_facets_on_start with a stubbed DB hook.

    Checks that inputs come from SQL parsing, outputs point at the GCS
    destination, and the namespace honors service_name/SID and port fallback.
    """

    # Minimal DbApiHook stand-in: class attributes capture the parametrized
    # values, and get_conn/get_connection are MagicMocks wired up below.
    class DBApiHookForTests(DbApiHook):
        conn_name_attr = "sql_default"
        get_conn = MagicMock(name="conn")
        get_connection = MagicMock()
        service_name = input_service_name
        sid = input_sid

        def get_openlineage_database_info(self, connection):
            from airflow.providers.openlineage.sqlparser import DatabaseInfo

            return DatabaseInfo(
                scheme="oracle",
                authority=DbApiHook.get_openlineage_authority_part(connection, default_port=default_port),
            )

    dbapi_hook = DBApiHookForTests()

    # Override db_hook so the operator talks to the stub instead of OracleHook.
    class OracleToGCSOperatorForTest(OracleToGCSOperator):
        @property
        def db_hook(self):
            return dbapi_hook

    sql = """SELECT employee_id, first_name FROM hr.employees"""
    op = OracleToGCSOperatorForTest(task_id=TASK_ID, sql=sql, bucket="bucket", filename="dir/file{}.csv")
    DB_SCHEMA_NAME = "HR"
    # Rows in ALL_TAB_COLUMNS shape: (owner, table, column, position, type).
    rows = [
        (DB_SCHEMA_NAME, "employees", "employee_id", 1, "NUMBER"),
        (DB_SCHEMA_NAME, "employees", "first_name", 2, "VARCHAR2"),
        (DB_SCHEMA_NAME, "employees", "age", 3, "NUMBER"),
    ]
    dbapi_hook.get_connection.return_value = Connection(
        conn_id="sql_default", conn_type="oracle", host="host", port=connection_port
    )
    # First fetchall returns the schema rows; second returns nothing.
    dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]

    lineage = op.get_openlineage_facets_on_start()
    assert len(lineage.inputs) == 1
    assert lineage.inputs[0].namespace == f"oracle://host:{expected_port}"
    assert lineage.inputs[0].name == f"{input_service_name or input_sid}.HR.employees"
    assert len(lineage.inputs[0].facets) == 1
    assert lineage.inputs[0].facets["schema"].fields == [
        SchemaDatasetFacetFields(name="employee_id", type="NUMBER"),
        SchemaDatasetFacetFields(name="first_name", type="VARCHAR2"),
        SchemaDatasetFacetFields(name="age", type="NUMBER"),
    ]
    # Output is the GCS prefix (filename directory) in the bucket namespace.
    assert lineage.outputs == [
        OutputDataset(
            namespace="gs://bucket",
            name="dir",
        )
    ]

    assert len(lineage.job_facets) == 1
    assert lineage.job_facets["sql"].query == sql
    assert lineage.run_facets == {}
4 changes: 4 additions & 0 deletions providers/oracle/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,17 @@ dependencies = [
"numpy>=1.26.0; python_version=='3.12'",
"numpy>=2.1.0; python_version>='3.13'",
]
"openlineage" = [
"apache-airflow-providers-openlineage"
]

[dependency-groups]
dev = [
"apache-airflow",
"apache-airflow-task-sdk",
"apache-airflow-devel-common",
"apache-airflow-providers-common-sql",
"apache-airflow-providers-openlineage",
# Additional devel dependencies (do not remove this line and add extra development dependencies)
"numpy>=1.22.4; python_version<'3.11'",
"numpy>=1.23.2; python_version=='3.11'",
Expand Down
47 changes: 46 additions & 1 deletion providers/oracle/src/airflow/providers/oracle/hooks/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@
import math
import warnings
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING, Any

import oracledb

if TYPE_CHECKING:
from airflow.models.connection import Connection
from airflow.providers.openlineage.sqlparser import DatabaseInfo

from airflow.providers.common.sql.hooks.sql import DbApiHook

DEFAULT_DB_PORT = 1521
Expand Down Expand Up @@ -116,6 +120,20 @@ def __init__(
self.thick_mode_config_dir = thick_mode_config_dir
self.fetch_decimals = fetch_decimals
self.fetch_lobs = fetch_lobs
self._service_name: str | None = None
self._sid: str | None = None

@property
def service_name(self) -> str | None:
    """Oracle service name read from the connection's ``extra`` JSON, cached after first lookup.

    NOTE(review): when the key is absent the result stays ``None`` and the
    connection is re-fetched on every access — confirm this is acceptable.
    """
    if self._service_name is None:
        extra = self.get_connection(self.get_conn_id()).extra_dejson
        self._service_name = extra.get("service_name")
    return self._service_name

@property
def sid(self) -> str | None:
    """Oracle SID read from the connection's ``extra`` JSON, cached after first lookup.

    NOTE(review): when the key is absent the result stays ``None`` and the
    connection is re-fetched on every access — confirm this is acceptable.
    """
    if self._sid is None:
        extra = self.get_connection(self.get_conn_id()).extra_dejson
        self._sid = extra.get("sid")
    return self._sid

def get_conn(self) -> oracledb.Connection:
"""
Expand Down Expand Up @@ -448,6 +466,33 @@ def handler(cursor):

return result

def get_openlineage_database_info(self, connection: Connection) -> DatabaseInfo:
    """Return Oracle specific information for OpenLineage.

    Uses ``ALL_TAB_COLUMNS`` as the information-schema source and identifies
    the database by service name, falling back to SID.
    """
    # Local import keeps the openlineage provider an optional dependency.
    from airflow.providers.openlineage.sqlparser import DatabaseInfo

    # Columns queried from ALL_TAB_COLUMNS to reconstruct table schemas.
    schema_columns = [
        "owner",
        "table_name",
        "column_name",
        "column_id",
        "data_type",
    ]
    authority = DbApiHook.get_openlineage_authority_part(connection, default_port=DEFAULT_DB_PORT)
    return DatabaseInfo(
        scheme=self.get_openlineage_database_dialect(connection),
        authority=authority,
        information_schema_table_name="ALL_TAB_COLUMNS",
        information_schema_columns=schema_columns,
        database=self.service_name or self.sid,
        # Names are upper-cased to match Oracle's identifier casing.
        normalize_name_method=lambda name: name.upper(),
    )

def get_openlineage_database_dialect(self, _) -> str:
    """Return the SQL dialect name reported to OpenLineage; always ``"oracle"``."""
    return "oracle"

def get_openlineage_default_schema(self) -> str | None:
    """Return the session's current schema, as reported by Oracle's USERENV context."""
    row = self.get_first("SELECT SYS_CONTEXT('USERENV', 'CURRENT_SCHEMA') FROM dual")
    return row[0]

def get_uri(self) -> str:
"""Get the URI for the Oracle connection."""
conn = self.get_connection(self.get_conn_id())
Expand Down
40 changes: 40 additions & 0 deletions providers/oracle/tests/unit/oracle/hooks/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,3 +525,43 @@ def test_test_connection_use_dual_table(self):
self.cur.execute.assert_called_once_with("select 1 from dual")
assert status is True
assert message == "Connection successfully tested"

def test_get_openlineage_database_info_with_service_name(self):
    """DatabaseInfo should use ``service_name`` from the connection extras as the database."""
    connection = Connection(
        conn_id="oracle_default",
        conn_type="oracle",
        host="localhost",
        port=1521,
        extra='{"service_name": "ORCLPDB1"}',
    )
    hook = OracleHook(oracle_conn_id="oracle_default")
    # Bypass the metastore: always hand back our in-memory connection.
    hook.get_connection = lambda _: connection

    assert hook.service_name == "ORCLPDB1"

    db_info = hook.get_openlineage_database_info(connection)
    assert db_info.database == "ORCLPDB1"
    assert db_info.scheme == "oracle"
    assert db_info.authority == "localhost:1521"
    assert db_info.information_schema_table_name == "ALL_TAB_COLUMNS"
    assert "owner" in db_info.information_schema_columns
    assert db_info.normalize_name_method("employees") == "EMPLOYEES"

def test_get_openlineage_database_info_with_sid(self):
    """DatabaseInfo should fall back to ``sid`` when no service_name is configured."""
    connection = Connection(
        conn_id="oracle_default",
        conn_type="oracle",
        host="dbhost",
        port=1521,
        extra='{"sid": "XE"}',
    )
    hook = OracleHook(oracle_conn_id="oracle_default")
    # Bypass the metastore: always hand back our in-memory connection.
    hook.get_connection = lambda _: connection

    assert hook.sid == "XE"

    db_info = hook.get_openlineage_database_info(connection)
    assert db_info.database == "XE"
    assert db_info.scheme == "oracle"
    assert db_info.authority == "dbhost:1521"
    assert db_info.information_schema_table_name == "ALL_TAB_COLUMNS"
    assert "owner" in db_info.information_schema_columns
    assert db_info.normalize_name_method("employees") == "EMPLOYEES"