6 changes: 3 additions & 3 deletions dev/breeze/tests/test_selective_checks.py
@@ -1752,7 +1752,7 @@ def test_expected_output_push(
"airflow/datasets/",
),
{
"selected-providers-list-as-string": "amazon common.compat common.io common.sql dbt.cloud ftp google mysql openlineage postgres sftp snowflake trino",
"selected-providers-list-as-string": "amazon common.compat common.io common.sql dbt.cloud ftp google microsoft.mssql mysql openlineage postgres sftp snowflake trino",
"all-python-versions": "['3.9']",
"all-python-versions-list-as-string": "3.9",
"ci-image-build": "true",
@@ -1762,13 +1762,13 @@
"skip-providers-tests": "false",
"test-groups": "['core', 'providers']",
"docs-build": "true",
"docs-list-as-string": "apache-airflow amazon common.compat common.io common.sql dbt.cloud ftp google mysql openlineage postgres sftp snowflake trino",
"docs-list-as-string": "apache-airflow amazon common.compat common.io common.sql dbt.cloud ftp google microsoft.mssql mysql openlineage postgres sftp snowflake trino",
"skip-pre-commits": "check-provider-yaml-valid,flynt,identity,lint-helm-chart,mypy-airflow,mypy-dev,mypy-docs,mypy-providers,mypy-task-sdk,"
"ts-compile-format-lint-ui,ts-compile-format-lint-www",
"run-kubernetes-tests": "false",
"upgrade-to-newer-dependencies": "false",
"core-test-types-list-as-string": "API Always CLI Core Operators Other Serialization WWW",
"providers-test-types-list-as-string": "Providers[amazon] Providers[common.compat,common.io,common.sql,dbt.cloud,ftp,mysql,openlineage,postgres,sftp,snowflake,trino] Providers[google]",
"providers-test-types-list-as-string": "Providers[amazon] Providers[common.compat,common.io,common.sql,dbt.cloud,ftp,microsoft.mssql,mysql,openlineage,postgres,sftp,snowflake,trino] Providers[google]",
"needs-mypy": "false",
"mypy-checks": "[]",
},
3 changes: 2 additions & 1 deletion generated/provider_dependencies.json
@@ -860,7 +860,8 @@
"devel-deps": [],
"plugins": [],
"cross-providers-deps": [
"common.sql"
"common.sql",
"openlineage"
],
"excluded-python-versions": [],
"state": "ready"
providers/src/airflow/providers/google/cloud/transfers/mssql_to_gcs.py
@@ -22,10 +22,15 @@
import datetime
import decimal
from collections.abc import Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING

from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook

+if TYPE_CHECKING:
+    from airflow.providers.openlineage.extractors import OperatorLineage


class MSSQLToGCSOperator(BaseSQLToGCSOperator):
"""
@@ -75,14 +80,17 @@ def __init__(
        self.mssql_conn_id = mssql_conn_id
        self.bit_fields = bit_fields or []

+    @cached_property
+    def db_hook(self) -> MsSqlHook:
+        return MsSqlHook(mssql_conn_id=self.mssql_conn_id)

    def query(self):
        """
        Query MSSQL and return a cursor of results.

        :return: mssql cursor
        """
-        mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id)
-        conn = mssql.get_conn()
+        conn = self.db_hook.get_conn()
        cursor = conn.cursor()
        cursor.execute(self.sql)
        return cursor
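
Note on the refactor above: `query()` previously built its own `MsSqlHook`, and moving construction into a `cached_property` lets the OpenLineage method added below reuse the exact same hook. A minimal, self-contained sketch of the `cached_property` behavior (the `Example` class and its body are illustrative, not the provider's code):

```python
from functools import cached_property


class Example:
    @cached_property
    def db_hook(self):
        # Runs once on first access; the result is stored on the instance,
        # so every later access returns the same cached object.
        print("building hook")
        return object()


e = Example()
e.db_hook                      # prints "building hook"
assert e.db_hook is e.db_hook  # cached: constructed only once
```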
@@ -109,3 +117,20 @@ def convert_type(cls, value, schema_type, **kwargs):
        if isinstance(value, (datetime.date, datetime.time)):
            return value.isoformat()
        return value

+    def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
+        from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
+        from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        sql_parsing_result = get_openlineage_facets_with_sql(
+            hook=self.db_hook,
+            sql=self.sql,
+            conn_id=self.mssql_conn_id,
+            database=None,
+        )
+        gcs_output_datasets = self._get_openlineage_output_datasets()
+        if sql_parsing_result:
+            sql_parsing_result.outputs = gcs_output_datasets
+            return sql_parsing_result
+        return OperatorLineage(outputs=gcs_output_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
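
For context, a rough usage sketch of the operator with this change in place. The connection ID, bucket, and SQL are invented; with the `apache-airflow-providers-openlineage` package installed and configured, the listener should invoke `get_openlineage_facets_on_start()` automatically at task start, so no extra DAG code is required:

```python
from airflow.providers.google.cloud.transfers.mssql_to_gcs import MSSQLToGCSOperator

# Hypothetical task definition; mssql_conn_id, bucket, and SQL are examples only.
upload_orders = MSSQLToGCSOperator(
    task_id="orders_to_gcs",
    mssql_conn_id="mssql_default",
    sql="SELECT * FROM dbo.orders",
    bucket="example-bucket",
    filename="orders/export_{}.json",
)
```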
28 changes: 28 additions & 0 deletions providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py
@@ -29,6 +29,7 @@

if TYPE_CHECKING:
    from airflow.providers.common.sql.dialects.dialect import Dialect
+    from airflow.providers.openlineage.sqlparser import DatabaseInfo


class MsSqlHook(DbApiHook):
@@ -117,3 +118,30 @@ def set_autocommit(

    def get_autocommit(self, conn: PymssqlConnection):
        return conn.autocommit_state

+    def get_openlineage_database_info(self, connection) -> DatabaseInfo:
+        """Return MSSQL specific information for OpenLineage."""
+        from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+        return DatabaseInfo(
+            scheme=self.get_openlineage_database_dialect(connection),
+            authority=DbApiHook.get_openlineage_authority_part(connection, default_port=1433),
+            information_schema_columns=[
+                "table_schema",
+                "table_name",
+                "column_name",
+                "ordinal_position",
+                "data_type",
+                "table_catalog",
+            ],
+            database=self.schema or self.connection.schema,
+            is_information_schema_cross_db=True,
+        )
+
+    def get_openlineage_database_dialect(self, connection) -> str:
+        """Return database dialect."""
+        return "mssql"
+
+    def get_openlineage_default_schema(self) -> str | None:
+        """Return current schema."""
+        return self.get_first("SELECT SCHEMA_NAME();")[0]
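
The `DatabaseInfo` above is what the OpenLineage SQL parser consumes: `scheme` and `authority` form the dataset namespace, while the database and default schema qualify table names. A rough illustration of the resulting coordinates (host, database, and table are invented):

```python
# Hypothetical values showing how the hook's DatabaseInfo maps to
# OpenLineage dataset coordinates for a table dbo.orders.
scheme = "mssql"                       # from get_openlineage_database_dialect()
authority = "db.example.com:1433"      # from get_openlineage_authority_part(), default_port=1433
namespace = f"{scheme}://{authority}"  # -> "mssql://db.example.com:1433"
name = "mydb.dbo.orders"               # database.schema.table
print(namespace, name)
```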
65 changes: 65 additions & 0 deletions providers/tests/google/cloud/transfers/test_mssql_to_gcs.py
@@ -22,6 +22,12 @@

import pytest

+from airflow.models import Connection
+from airflow.providers.common.compat.openlineage.facet import (
+    OutputDataset,
+    SchemaDatasetFacetFields,
+)
+from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.google.cloud.transfers.mssql_to_gcs import MSSQLToGCSOperator

TASK_ID = "test-mssql-to-gcs"
@@ -188,3 +194,62 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):

        # once for the file and once for the schema
        assert gcs_hook_mock.upload.call_count == 2

+    @pytest.mark.parametrize(
+        "connection_port, default_port, expected_port",
+        [(None, 4321, 4321), (1234, None, 1234), (1234, 4321, 1234)],
+    )
+    def test_execute_openlineage_events(self, connection_port, default_port, expected_port):
+        class DBApiHookForTests(DbApiHook):
+            conn_name_attr = "sql_default"
+            get_conn = mock.MagicMock(name="conn")
+            get_connection = mock.MagicMock()
+
+            def get_openlineage_database_info(self, connection):
+                from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+                return DatabaseInfo(
+                    scheme="sqlscheme",
+                    authority=DbApiHook.get_openlineage_authority_part(connection, default_port=default_port),
+                )
+
+        dbapi_hook = DBApiHookForTests()
+
+        class MSSQLToGCSOperatorForTest(MSSQLToGCSOperator):
+            @property
+            def db_hook(self):
+                return dbapi_hook
+
+        sql = """SELECT a,b,c from my_db.my_table"""
+        op = MSSQLToGCSOperatorForTest(task_id=TASK_ID, sql=sql, bucket="bucket", filename="dir/file{}.csv")
+        DB_SCHEMA_NAME = "PUBLIC"
+        rows = [
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_day_of_week", 1, "varchar"),
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_placed_on", 2, "timestamp"),
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "orders_placed", 3, "int4"),
+        ]
+        dbapi_hook.get_connection.return_value = Connection(
+            conn_id="sql_default", conn_type="mssql", host="host", port=connection_port
+        )
+        dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]
+
+        lineage = op.get_openlineage_facets_on_start()
+        assert len(lineage.inputs) == 1
+        assert lineage.inputs[0].namespace == f"sqlscheme://host:{expected_port}"
+        assert lineage.inputs[0].name == "PUBLIC.popular_orders_day_of_week"
+        assert len(lineage.inputs[0].facets) == 1
+        assert lineage.inputs[0].facets["schema"].fields == [
+            SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"),
+            SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"),
+            SchemaDatasetFacetFields(name="orders_placed", type="int4"),
+        ]
+        assert lineage.outputs == [
+            OutputDataset(
+                namespace="gs://bucket",
+                name="dir",
+            )
+        ]
+
+        assert len(lineage.job_facets) == 1
+        assert lineage.job_facets["sql"].query == sql
+        assert lineage.run_facets == {}
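
The parametrization encodes the authority rule this test exercises: a port set on the `Connection` wins, and `default_port` is only a fallback. A sketch of that rule (`resolve_port` is a hypothetical helper, not the provider's API):

```python
from __future__ import annotations


def resolve_port(connection_port: int | None, default_port: int | None) -> int | None:
    # Mirrors the test's (connection_port, default_port, expected_port) cases:
    # an explicit connection port takes precedence; default_port is a fallback.
    return connection_port if connection_port is not None else default_port


assert resolve_port(None, 4321) == 4321
assert resolve_port(1234, None) == 1234
assert resolve_port(1234, 4321) == 1234
```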