Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

openlineage, gcs: add openlineage methods for GcsToGcsOperator #31350

Merged
merged 1 commit into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions airflow/providers/google/cloud/transfers/gcs_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ def __init__(
self.source_object_required = source_object_required
self.exact_match = exact_match
self.match_glob = match_glob
self.resolved_source_objects: set[str] = set()
self.resolved_target_objects: set[str] = set()

def execute(self, context: Context):

Expand Down Expand Up @@ -540,7 +542,34 @@ def _copy_single_object(self, hook, source_object, destination_object):
destination_object,
)

self.resolved_source_objects.add(source_object)
if not destination_object:
self.resolved_target_objects.add(source_object)
else:
self.resolved_target_objects.add(destination_object)

hook.rewrite(self.source_bucket, source_object, self.destination_bucket, destination_object)

if self.move_object:
hook.delete(self.source_bucket, source_object)

def get_openlineage_events_on_complete(self, task_instance):
    """Return OpenLineage lineage resolved while this operator executed.

    Implementing _on_complete because execute method does preprocessing on internals.
    This means we won't have to normalize self.source_object and self.source_objects,
    destination bucket and so on.

    :param task_instance: the running TaskInstance (unused here; required by the
        OpenLineage extractor calling convention).
    :return: an ``OperatorLineage`` whose inputs/outputs are one ``Dataset`` per
        object actually copied, collected into ``self.resolved_source_objects`` /
        ``self.resolved_target_objects`` by ``_copy_single_object``.
    """
    # Imports are deferred so the google provider does not hard-depend on the
    # openlineage packages when lineage is not used.
    from openlineage.client.run import Dataset

    from airflow.providers.openlineage.extractors import OperatorLineage

    return OperatorLineage(
        inputs=[
            # sorted() makes the emitted lineage deterministic across runs.
            Dataset(namespace=f"gs://{self.source_bucket}", name=source)
            for source in sorted(self.resolved_source_objects)
        ],
        outputs=[
            Dataset(namespace=f"gs://{self.destination_bucket}", name=target)
            for target in sorted(self.resolved_target_objects)
        ],
    )
18 changes: 10 additions & 8 deletions airflow/providers/openlineage/extractors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def get_operator_classnames(cls) -> list[str]:
return []

def extract(self) -> OperatorLineage | None:
# OpenLineage methods are optional - if there's no method, return None
try:
return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start) # type: ignore
except AttributeError:
Expand All @@ -100,19 +101,20 @@ def extract_on_complete(self, task_instance) -> OperatorLineage | None:

def _get_openlineage_facets(self, get_facets_method, *args) -> OperatorLineage | None:
    """Invoke an operator's optional OpenLineage method and sanitize the result.

    :param get_facets_method: bound ``get_openlineage_facets_on_*`` method of
        the operator being extracted.
    :param args: positional arguments forwarded to that method.
    :return: a freshly constructed ``OperatorLineage``, or ``None`` when the
        method failed (failures are logged, never propagated, so lineage
        collection cannot break task execution).
    """
    try:
        facets: OperatorLineage = get_facets_method(*args)
        # "rewrite" OperatorLineage to safeguard against different version of the same class
        # that was existing in openlineage-airflow package outside of Airflow repo
        return OperatorLineage(
            inputs=facets.inputs,
            outputs=facets.outputs,
            run_facets=facets.run_facets,
            job_facets=facets.job_facets,
        )
    except ImportError:
        self.log.exception(
            "OpenLineage provider method failed to import OpenLineage integration. "
            "This should not happen."
        )
    except Exception:
        self.log.exception("OpenLineage provider method failed to extract data from provider. ")
    return None
12 changes: 7 additions & 5 deletions dev/breeze/tests/test_selective_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def test_expected_output_full_tests_needed(
{
"affected-providers-list-as-string": "amazon apache.beam apache.cassandra cncf.kubernetes "
"common.sql facebook google hashicorp microsoft.azure microsoft.mssql "
"mysql oracle postgres presto salesforce sftp ssh trino",
"mysql openlineage oracle postgres presto salesforce sftp ssh trino",
"all-python-versions": "['3.8']",
"all-python-versions-list-as-string": "3.8",
"needs-helm-tests": "false",
Expand All @@ -564,8 +564,8 @@ def test_expected_output_full_tests_needed(
{
"affected-providers-list-as-string": "amazon apache.beam apache.cassandra "
"cncf.kubernetes common.sql facebook google "
"hashicorp microsoft.azure microsoft.mssql mysql oracle postgres presto "
"salesforce sftp ssh trino",
"hashicorp microsoft.azure microsoft.mssql mysql openlineage oracle postgres "
"presto salesforce sftp ssh trino",
"all-python-versions": "['3.8']",
"all-python-versions-list-as-string": "3.8",
"image-build": "true",
Expand Down Expand Up @@ -666,7 +666,7 @@ def test_expected_output_pull_request_v2_3(
"affected-providers-list-as-string": "amazon apache.beam apache.cassandra "
"cncf.kubernetes common.sql "
"facebook google hashicorp microsoft.azure microsoft.mssql mysql "
"oracle postgres presto salesforce sftp ssh trino",
"openlineage oracle postgres presto salesforce sftp ssh trino",
"all-python-versions": "['3.8']",
"all-python-versions-list-as-string": "3.8",
"image-build": "true",
Expand All @@ -685,6 +685,7 @@ def test_expected_output_pull_request_v2_3(
"--package-filter apache-airflow-providers-microsoft-azure "
"--package-filter apache-airflow-providers-microsoft-mssql "
"--package-filter apache-airflow-providers-mysql "
"--package-filter apache-airflow-providers-openlineage "
"--package-filter apache-airflow-providers-oracle "
"--package-filter apache-airflow-providers-postgres "
"--package-filter apache-airflow-providers-presto "
Expand All @@ -697,7 +698,7 @@ def test_expected_output_pull_request_v2_3(
"skip-provider-tests": "false",
"parallel-test-types-list-as-string": "Providers[amazon] Always CLI "
"Providers[apache.beam,apache.cassandra,cncf.kubernetes,common.sql,facebook,"
"hashicorp,microsoft.azure,microsoft.mssql,mysql,oracle,postgres,presto,"
"hashicorp,microsoft.azure,microsoft.mssql,mysql,openlineage,oracle,postgres,presto,"
"salesforce,sftp,ssh,trino] Providers[google]",
},
id="CLI tests and Google-related provider tests should run if cli/chart files changed",
Expand Down Expand Up @@ -965,6 +966,7 @@ def test_upgrade_to_newer_dependencies(files: tuple[str, ...], expected_outputs:
"--package-filter apache-airflow-providers-microsoft-azure "
"--package-filter apache-airflow-providers-microsoft-mssql "
"--package-filter apache-airflow-providers-mysql "
"--package-filter apache-airflow-providers-openlineage "
"--package-filter apache-airflow-providers-oracle "
"--package-filter apache-airflow-providers-postgres "
"--package-filter apache-airflow-providers-presto "
Expand Down
1 change: 1 addition & 0 deletions generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@
"microsoft.azure",
"microsoft.mssql",
"mysql",
"openlineage",
"oracle",
"postgres",
"presto",
Expand Down
73 changes: 73 additions & 0 deletions tests/providers/google/cloud/transfers/test_gcs_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from unittest import mock

import pytest
from openlineage.client.run import Dataset

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.transfers.gcs_to_gcs import WILDCARD, GCSToGCSOperator
Expand Down Expand Up @@ -827,3 +828,75 @@ def test_copy_files_into_a_folder(
for src, dst in zip(expected_source_objects, expected_destination_objects)
]
mock_hook.return_value.rewrite.assert_has_calls(mock_calls)

@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook")
def test_execute_simple_reports_openlineage(self, mock_hook):
    # Single-object copy with no destination_object: the lineage must contain
    # exactly one input and one output, both named after the source object.
    source_name = SOURCE_OBJECTS_SINGLE_FILE[0]
    op = GCSToGCSOperator(
        task_id=TASK_ID,
        source_bucket=TEST_BUCKET,
        source_object=source_name,
        destination_bucket=DESTINATION_BUCKET,
    )

    op.execute(None)

    lineage = op.get_openlineage_events_on_complete(None)
    assert len(lineage.inputs) == 1
    assert len(lineage.outputs) == 1
    assert lineage.inputs[0] == Dataset(namespace=f"gs://{TEST_BUCKET}", name=source_name)
    assert lineage.outputs[0] == Dataset(namespace=f"gs://{DESTINATION_BUCKET}", name=source_name)

@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook")
def test_execute_multiple_reports_openlineage(self, mock_hook):
    # Several sources funneled into one destination_object: three inputs,
    # but a single collapsed output dataset.
    op = GCSToGCSOperator(
        task_id=TASK_ID,
        source_bucket=TEST_BUCKET,
        source_objects=SOURCE_OBJECTS_LIST,
        destination_bucket=DESTINATION_BUCKET,
        destination_object=DESTINATION_OBJECT,
    )

    op.execute(None)

    lineage = op.get_openlineage_events_on_complete(None)
    assert len(lineage.inputs) == 3
    assert len(lineage.outputs) == 1
    assert lineage.inputs == [
        Dataset(namespace=f"gs://{TEST_BUCKET}", name=name) for name in SOURCE_OBJECTS_LIST[:3]
    ]
    assert lineage.outputs[0] == Dataset(namespace=f"gs://{DESTINATION_BUCKET}", name=DESTINATION_OBJECT)

@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSHook")
def test_execute_wildcard_reports_openlineage(self, mock_hook):
    # Wildcard source: the mocked hook.list() resolves the glob to two
    # objects, which must surface as two inputs and two prefixed outputs.
    mock_hook.return_value.list.return_value = ["test_object1.txt", "test_object2.txt"]

    op = GCSToGCSOperator(
        task_id=TASK_ID,
        source_bucket=TEST_BUCKET,
        source_object=SOURCE_OBJECT_WILDCARD_SUFFIX,
        destination_bucket=DESTINATION_BUCKET,
        destination_object=DESTINATION_OBJECT,
    )

    op.execute(None)

    lineage = op.get_openlineage_events_on_complete(None)
    assert len(lineage.inputs) == 2
    assert len(lineage.outputs) == 2
    assert lineage.inputs == [
        Dataset(namespace=f"gs://{TEST_BUCKET}", name="test_object1.txt"),
        Dataset(namespace=f"gs://{TEST_BUCKET}", name="test_object2.txt"),
    ]
    assert lineage.outputs == [
        Dataset(namespace=f"gs://{DESTINATION_BUCKET}", name="foo/bar/1.txt"),
        Dataset(namespace=f"gs://{DESTINATION_BUCKET}", name="foo/bar/2.txt"),
    ]