Add OpenLineage support to GCS operators #35838

Merged 1 commit on Nov 27, 2023
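
This change adds OpenLineage facet collection to three GCS operators: GCSDeleteObjectsOperator implements get_openlineage_facets_on_complete, emitting each deleted object as an input dataset carrying a lifecycleStateChange (DROP) facet; GCSFileTransformOperator implements get_openlineage_facets_on_start, since its source and destination objects are known before execution; and GCSTimeSpanFileTransformOperator implements get_openlineage_facets_on_complete, since its object names are resolved only during execute(). Unit tests cover all three operators.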
95 changes: 85 additions & 10 deletions airflow/providers/google/cloud/operators/gcs.py
@@ -313,6 +313,7 @@ def __init__(
)
raise ValueError(err_message)

self._objects: list[str] = []
super().__init__(**kwargs)

def execute(self, context: Context) -> None:
@@ -322,13 +323,47 @@ def execute(self, context: Context) -> None:
)

        if self.objects is not None:
-            objects = self.objects
+            self._objects = self.objects
        else:
-            objects = hook.list(bucket_name=self.bucket_name, prefix=self.prefix)
-        self.log.info("Deleting %s objects from %s", len(objects), self.bucket_name)
-        for object_name in objects:
+            self._objects = hook.list(bucket_name=self.bucket_name, prefix=self.prefix)
+        self.log.info("Deleting %s objects from %s", len(self._objects), self.bucket_name)
+        for object_name in self._objects:
            hook.delete(bucket_name=self.bucket_name, object_name=object_name)

def get_openlineage_facets_on_complete(self, task_instance):
"""Implementing on_complete as execute() resolves object names."""
from openlineage.client.facet import (
LifecycleStateChange,
LifecycleStateChangeDatasetFacet,
LifecycleStateChangeDatasetFacetPreviousIdentifier,
)
from openlineage.client.run import Dataset

from airflow.providers.openlineage.extractors import OperatorLineage

if not self._objects:
return OperatorLineage()

bucket_url = f"gs://{self.bucket_name}"
input_datasets = [
Dataset(
namespace=bucket_url,
name=object_name,
facets={
"lifecycleStateChange": LifecycleStateChangeDatasetFacet(
lifecycleStateChange=LifecycleStateChange.DROP.value,
previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier(
namespace=bucket_url,
name=object_name,
),
)
},
)
for object_name in self._objects
]

return OperatorLineage(inputs=input_datasets)


class GCSBucketCreateAclEntryOperator(GoogleCloudBaseOperator):
"""
@@ -596,6 +631,22 @@ def execute(self, context: Context) -> None:
filename=destination_file.name,
)

def get_openlineage_facets_on_start(self):
from openlineage.client.run import Dataset

from airflow.providers.openlineage.extractors import OperatorLineage

input_dataset = Dataset(
namespace=f"gs://{self.source_bucket}",
name=self.source_object,
)
output_dataset = Dataset(
namespace=f"gs://{self.destination_bucket}",
name=self.destination_object,
)

return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])


class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
"""
@@ -722,6 +773,9 @@ def __init__(
self.upload_continue_on_fail = upload_continue_on_fail
self.upload_num_attempts = upload_num_attempts

self._source_object_names: list[str] = []
self._destination_object_names: list[str] = []

def execute(self, context: Context) -> list[str]:
# Define intervals and prefixes.
try:
@@ -773,7 +827,7 @@ def execute(self, context: Context) -> list[str]:
)

# Fetch list of files.
-        blobs_to_transform = source_hook.list_by_timespan(
+        self._source_object_names = source_hook.list_by_timespan(
bucket_name=self.source_bucket,
prefix=source_prefix_interp,
timespan_start=timespan_start,
@@ -785,7 +839,7 @@ def execute(self, context: Context) -> list[str]:
temp_output_dir_path = Path(temp_output_dir)

# TODO: download in parallel.
-            for blob_to_transform in blobs_to_transform:
+            for blob_to_transform in self._source_object_names:
destination_file = temp_input_dir_path / blob_to_transform
destination_file.parent.mkdir(parents=True, exist_ok=True)
try:
Expand Down Expand Up @@ -822,8 +876,6 @@ def execute(self, context: Context) -> list[str]:

self.log.info("Transformation succeeded. Output temporarily located at %s", temp_output_dir_path)

-            files_uploaded = []
-
# TODO: upload in parallel.
for upload_file in temp_output_dir_path.glob("**/*"):
if upload_file.is_dir():
@@ -844,12 +896,35 @@ def execute(self, context: Context) -> list[str]:
chunk_size=self.chunk_size,
num_max_attempts=self.upload_num_attempts,
)
-                    files_uploaded.append(str(upload_file_name))
+                    self._destination_object_names.append(str(upload_file_name))
except GoogleCloudError:
if not self.upload_continue_on_fail:
raise

-        return files_uploaded
+        return self._destination_object_names

def get_openlineage_facets_on_complete(self, task_instance):
"""Implementing on_complete as execute() resolves object names."""
from openlineage.client.run import Dataset

from airflow.providers.openlineage.extractors import OperatorLineage

input_datasets = [
Dataset(
namespace=f"gs://{self.source_bucket}",
name=object_name,
)
for object_name in self._source_object_names
]
output_datasets = [
Dataset(
namespace=f"gs://{self.destination_bucket}",
name=object_name,
)
for object_name in self._destination_object_names
]

return OperatorLineage(inputs=input_datasets, outputs=output_datasets)


class GCSDeleteBucketOperator(GoogleCloudBaseOperator):
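For orientation, here is a minimal sketch (not part of this diff, and simplified relative to the real listener in the openlineage provider) of how these two hook methods are typically driven; collect_lineage and its arguments are illustrative names only:

# Illustrative sketch: why operators expose two different OpenLineage hooks.
# get_openlineage_facets_on_start() can be called before execution, when the
# lineage is static (e.g. GCSFileTransformOperator). Operators that resolve
# object names at runtime (e.g. GCSDeleteObjectsOperator with a prefix) expose
# get_openlineage_facets_on_complete() instead, called after execute().
def collect_lineage(operator, context):
    lineage = None
    if hasattr(operator, "get_openlineage_facets_on_start"):
        lineage = operator.get_openlineage_facets_on_start()
    result = operator.execute(context)
    if hasattr(operator, "get_openlineage_facets_on_complete"):
        lineage = operator.get_openlineage_facets_on_complete(task_instance=None)
    return result, lineage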
160 changes: 160 additions & 0 deletions tests/providers/google/cloud/operators/test_gcs.py
@@ -21,6 +21,13 @@
from pathlib import Path
from unittest import mock

from openlineage.client.facet import (
LifecycleStateChange,
LifecycleStateChangeDatasetFacet,
LifecycleStateChangeDatasetFacetPreviousIdentifier,
)
from openlineage.client.run import Dataset

from airflow.providers.google.cloud.operators.gcs import (
GCSBucketCreateAclEntryOperator,
GCSCreateBucketOperator,
@@ -164,6 +171,49 @@ def test_delete_prefix_as_empty_string(self, mock_hook):
any_order=True,
)

@mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
def test_get_openlineage_facets_on_complete(self, mock_hook):
bucket_url = f"gs://{TEST_BUCKET}"
expected_inputs = [
Dataset(
namespace=bucket_url,
name="folder/a.txt",
facets={
"lifecycleStateChange": LifecycleStateChangeDatasetFacet(
lifecycleStateChange=LifecycleStateChange.DROP.value,
previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier(
namespace=bucket_url,
name="folder/a.txt",
),
)
},
),
Dataset(
namespace=bucket_url,
name="b.txt",
facets={
"lifecycleStateChange": LifecycleStateChangeDatasetFacet(
lifecycleStateChange=LifecycleStateChange.DROP.value,
previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier(
namespace=bucket_url,
name="b.txt",
),
)
},
),
]

operator = GCSDeleteObjectsOperator(
task_id=TASK_ID, bucket_name=TEST_BUCKET, objects=["folder/a.txt", "b.txt"]
)

operator.execute(None)

lineage = operator.get_openlineage_facets_on_complete(None)
assert len(lineage.inputs) == 2
assert len(lineage.outputs) == 0
assert lineage.inputs == expected_inputs


class TestGoogleCloudStorageListOperator:
@mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
@@ -251,6 +301,31 @@ def test_execute(self, mock_hook, mock_subprocess, mock_tempfile):
filename=destination,
)

def test_get_openlineage_facets_on_start(self):
expected_input = Dataset(
namespace=f"gs://{TEST_BUCKET}",
name="folder/a.txt",
)
expected_output = Dataset(
namespace=f"gs://{TEST_BUCKET}2",
name="b.txt",
)

operator = GCSFileTransformOperator(
task_id=TASK_ID,
source_bucket=TEST_BUCKET,
source_object="folder/a.txt",
destination_bucket=f"{TEST_BUCKET}2",
destination_object="b.txt",
transform_script="/path/to_script",
)

lineage = operator.get_openlineage_facets_on_start()
assert len(lineage.inputs) == 1
assert len(lineage.outputs) == 1
assert lineage.inputs[0] == expected_input
assert lineage.outputs[0] == expected_output


class TestGCSTimeSpanFileTransformOperatorDateInterpolation:
def test_execute(self):
@@ -408,6 +483,91 @@ def test_execute(self, mock_hook, mock_subprocess, mock_tempdir):
]
)

@mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
@mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
@mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
def test_get_openlineage_facets_on_complete(self, mock_hook, mock_subprocess, mock_tempdir):
source_bucket = TEST_BUCKET
source_prefix = "source_prefix"

destination_bucket = TEST_BUCKET + "_dest"
destination_prefix = "destination_prefix"
destination = "destination"

file1 = "file1"
file2 = "file2"

timespan_start = datetime(2015, 2, 1, 15, 16, 17, 345, tzinfo=timezone.utc)
mock_dag = mock.Mock()
mock_dag.following_schedule = lambda x: x + timedelta(hours=1)
context = dict(
execution_date=timespan_start,
dag=mock_dag,
ti=mock.Mock(),
)

mock_tempdir.return_value.__enter__.side_effect = ["source", destination]
mock_hook.return_value.list_by_timespan.return_value = [
f"{source_prefix}/{file1}",
f"{source_prefix}/{file2}",
]

mock_proc = mock.MagicMock()
mock_proc.returncode = 0
mock_proc.stdout.readline = lambda: b""
mock_proc.wait.return_value = None
mock_popen = mock.MagicMock()
mock_popen.return_value.__enter__.return_value = mock_proc

mock_subprocess.Popen = mock_popen
mock_subprocess.PIPE = "pipe"
mock_subprocess.STDOUT = "stdout"

op = GCSTimeSpanFileTransformOperator(
task_id=TASK_ID,
source_bucket=source_bucket,
source_prefix=source_prefix,
source_gcp_conn_id="",
destination_bucket=destination_bucket,
destination_prefix=destination_prefix,
destination_gcp_conn_id="",
transform_script="script.py",
)

with mock.patch.object(Path, "glob") as path_glob:
path_glob.return_value.__iter__.return_value = [
Path(f"{destination}/{file1}"),
Path(f"{destination}/{file2}"),
]
op.execute(context=context)

expected_inputs = [
Dataset(
namespace=f"gs://{source_bucket}",
name=f"{source_prefix}/{file1}",
),
Dataset(
namespace=f"gs://{source_bucket}",
name=f"{source_prefix}/{file2}",
),
]
expected_outputs = [
Dataset(
namespace=f"gs://{destination_bucket}",
name=f"{destination_prefix}/{file1}",
),
Dataset(
namespace=f"gs://{destination_bucket}",
name=f"{destination_prefix}/{file2}",
),
]

lineage = op.get_openlineage_facets_on_complete(None)
assert len(lineage.inputs) == 2
assert len(lineage.outputs) == 2
assert lineage.inputs == expected_inputs
assert lineage.outputs == expected_outputs


class TestGCSDeleteBucketOperator:
@mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
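The new tests can be run locally with pytest, for example: pytest tests/providers/google/cloud/operators/test_gcs.py -k openlineage, assuming a development environment with the google and openlineage provider packages installed.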