@@ -21,11 +21,17 @@

import warnings
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook

if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage


class BigQueryToMySqlOperator(BigQueryToSqlBaseOperator):
"""
@@ -76,5 +82,69 @@ def __init__(
)
self.mysql_conn_id = mysql_conn_id

@cached_property
def mysql_hook(self) -> MySqlHook:
return MySqlHook(schema=self.database, mysql_conn_id=self.mysql_conn_id)

def get_sql_hook(self) -> MySqlHook:
return self.mysql_hook

def execute(self, context):
# Set source_project_dataset_table here, after hooks are initialized and project_id is available
project_id = self.bigquery_hook.project_id
self.source_project_dataset_table = f"{project_id}.{self.dataset_id}.{self.table_id}"
return super().execute(context)

def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
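"""Collect OpenLineage facets after execution: the BigQuery source table as input and the MySQL target table as output."""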
from airflow.providers.common.compat.openlineage.facet import Dataset
from airflow.providers.google.cloud.openlineage.utils import (
BIGQUERY_NAMESPACE,
get_facets_from_bq_table_for_given_fields,
get_identity_column_lineage_facet,
)
from airflow.providers.openlineage.extractors import OperatorLineage

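# Fall back to creating a BigQuery hook if one is not already set on the operator.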
if not self.bigquery_hook:
self.bigquery_hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
location=self.location,
impersonation_chain=self.impersonation_chain,
)

try:
table_obj = self.bigquery_hook.get_client().get_table(self.source_project_dataset_table)
except Exception:
self.log.debug(
"OpenLineage: could not fetch BigQuery table %s",
self.source_project_dataset_table,
exc_info=True,
)
return OperatorLineage()

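# Use the explicitly selected fields when provided; otherwise fall back to every column in the table schema.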
if self.selected_fields:
if isinstance(self.selected_fields, str):
# A string value holds comma-separated column names
bigquery_field_names = self.selected_fields.split(",")
else:
bigquery_field_names = self.selected_fields
else:
bigquery_field_names = [f.name for f in getattr(table_obj, "schema", [])]

input_dataset = Dataset(
namespace=BIGQUERY_NAMESPACE,
name=self.source_project_dataset_table,
facets=get_facets_from_bq_table_for_given_fields(table_obj, bigquery_field_names),
)

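# Build the OpenLineage namespace for the MySQL side from the live connection, e.g. mysql://host:port.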
db_info = self.mysql_hook.get_openlineage_database_info(self.mysql_hook.get_conn())
namespace = f"{db_info.scheme}://{db_info.authority}"

output_name = f"{self.database}.{self.target_table_name}"

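# Identity column lineage: each output column maps to the BigQuery column of the same name.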
column_lineage_facet = get_identity_column_lineage_facet(
bigquery_field_names, input_datasets=[input_dataset]
)

output_facets = column_lineage_facet or {}
output_dataset = Dataset(namespace=namespace, name=output_name, facets=output_facets)

return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
@@ -18,13 +18,33 @@
from __future__ import annotations

from unittest import mock
from unittest.mock import MagicMock

from airflow.providers.google.cloud.transfers.bigquery_to_mysql import BigQueryToMySqlOperator

TASK_ID = "test-bq-create-table-operator"
TEST_DATASET = "test-dataset"
TEST_TABLE_ID = "test-table-id"
TEST_DAG_ID = "test-bigquery-operators"
TEST_PROJECT = "test-project"


def _make_bq_table(schema_names: list[str]):
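"""Build a minimal stand-in for a BigQuery Table object with the given schema field names and the attributes the lineage code reads."""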
class TableObj:
def __init__(self, schema):
self.schema = []
for n in schema:
field = MagicMock()
field.name = n
self.schema.append(field)
self.description = "table description"
self.external_data_configuration = None
self.labels = {}
self.num_rows = 0
self.num_bytes = 0
self.table_type = "TABLE"

return TableObj(schema_names)


class TestBigQueryToMySqlOperator:
@@ -46,3 +66,89 @@ def test_execute_good_request_to_bq(self, mock_hook):
selected_fields=None,
start_index=0,
)

@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mysql.MySqlHook")
def test_get_openlineage_facets_on_complete_no_selected_fields(self, mock_mysql_hook, mock_bq_hook):
mock_bq_client = MagicMock()
mock_bq_client.get_table.return_value = _make_bq_table(["id", "name", "value"])
mock_bq_hook.get_client.return_value = mock_bq_client
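# Let the patched class return this same mock so a hook instantiated inside the operator shares the configuration above.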
mock_bq_hook.return_value = mock_bq_hook

db_info = MagicMock(scheme="mysql", authority="localhost:3306", database="mydb")
mock_mysql_hook.get_openlineage_database_info.return_value = db_info
mock_mysql_hook.return_value = mock_mysql_hook

op = BigQueryToMySqlOperator(
task_id=TASK_ID,
dataset_table=f"{TEST_DATASET}.{TEST_TABLE_ID}",
target_table_name="destination",
selected_fields=None,
database="mydb",
)
op.bigquery_hook = mock_bq_hook
op.bigquery_hook.project_id = TEST_PROJECT
op.mysql_hook = mock_mysql_hook
context = mock.MagicMock()
op.execute(context=context)

result = op.get_openlineage_facets_on_complete(None)
assert len(result.inputs) == 1
assert len(result.outputs) == 1

input_ds = result.inputs[0]
assert input_ds.namespace == "bigquery"
assert input_ds.name == f"{TEST_PROJECT}.{TEST_DATASET}.{TEST_TABLE_ID}"
assert "schema" in input_ds.facets
schema_fields = [f.name for f in input_ds.facets["schema"].fields]
assert set(schema_fields) == {"id", "name", "value"}

output_ds = result.outputs[0]
assert output_ds.namespace == "mysql://localhost:3306"
assert output_ds.name == "mydb.destination"
assert "columnLineage" in output_ds.facets
col_lineage = output_ds.facets["columnLineage"]
assert set(col_lineage.fields.keys()) == {"id", "name", "value"}

@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mysql.MySqlHook")
def test_get_openlineage_facets_on_complete_selected_fields(self, mock_mysql_hook, mock_bq_hook):
mock_bq_client = MagicMock()
mock_bq_client.get_table.return_value = _make_bq_table(["id", "name", "value"])
mock_bq_hook.get_client.return_value = mock_bq_client
mock_bq_hook.return_value = mock_bq_hook

db_info = MagicMock(scheme="mysql", authority="localhost:3306", database="mydb")
mock_mysql_hook.get_openlineage_database_info.return_value = db_info
mock_mysql_hook.return_value = mock_mysql_hook

op = BigQueryToMySqlOperator(
task_id=TASK_ID,
dataset_table=f"{TEST_DATASET}.{TEST_TABLE_ID}",
target_table_name="destination",
selected_fields=["id", "name"],
database="mydb",
)
op.bigquery_hook = mock_bq_hook
op.bigquery_hook.project_id = TEST_PROJECT
op.mysql_hook = mock_mysql_hook
context = mock.MagicMock()
op.execute(context=context)

result = op.get_openlineage_facets_on_complete(None)
assert len(result.inputs) == 1
assert len(result.outputs) == 1

input_ds = result.inputs[0]
assert input_ds.namespace == "bigquery"
assert input_ds.name == f"{TEST_PROJECT}.{TEST_DATASET}.{TEST_TABLE_ID}"
assert "schema" in input_ds.facets
schema_fields = [f.name for f in input_ds.facets["schema"].fields]
assert set(schema_fields) == {"id", "name"}

output_ds = result.outputs[0]
assert output_ds.namespace == "mysql://localhost:3306"
assert output_ds.name == "mydb.destination"
assert "columnLineage" in output_ds.facets
col_lineage = output_ds.facets["columnLineage"]
assert set(col_lineage.fields.keys()) == {"id", "name"}