Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,94 +34,20 @@
inject_transport_information_into_spark_properties,
)
except ImportError:
try:
from airflow.providers.openlineage.plugins.macros import (
lineage_job_name,
lineage_job_namespace,
lineage_run_id,
)
except ImportError:

def inject_parent_job_information_into_spark_properties(properties: dict, context) -> dict:
    """No-op fallback used when the OpenLineage provider macros cannot be imported.

    Logs a warning and returns the Spark properties unchanged, so callers may
    invoke this unconditionally regardless of provider availability.

    :param properties: Existing Spark properties mapping.
    :param context: Airflow task context (unused in this fallback).
    :return: The ``properties`` mapping, unmodified.
    """
    # Fix: the two concatenated string literals previously joined without a
    # space, producing "...macros`.Skipping..." in the log output.
    log.warning(
        "Could not import `airflow.providers.openlineage.plugins.macros`. "
        "Skipping the injection of OpenLineage parent job information into Spark properties."
    )
    return properties

else:

def inject_parent_job_information_into_spark_properties(properties: dict, context) -> dict:
    """Merge OpenLineage parent-job lineage keys into the given Spark properties.

    If any ``spark.openlineage.parent*`` key is already present, the mapping is
    returned untouched so user-provided values are never overwritten.

    :param properties: Existing Spark properties mapping.
    :param context: Airflow task context; ``context["ti"]`` supplies the task instance.
    :return: A new mapping with parent-job keys added, or ``properties`` unchanged.
    """
    already_configured = any(
        str(prop_name).startswith("spark.openlineage.parent") for prop_name in properties
    )
    if already_configured:
        log.info(
            "Some OpenLineage properties with parent job information are already present "
            "in Spark properties. Skipping the injection of OpenLineage "
            "parent job information into Spark properties."
        )
        return properties

    task_instance = context["ti"]
    # Build the merged result without mutating the caller's dict.
    merged = dict(properties)
    merged["spark.openlineage.parentJobNamespace"] = lineage_job_namespace()
    merged["spark.openlineage.parentJobName"] = lineage_job_name(task_instance)
    merged["spark.openlineage.parentRunId"] = lineage_run_id(task_instance)
    return merged

try:
from airflow.providers.openlineage.plugins.listener import get_openlineage_listener
except ImportError:

def inject_transport_information_into_spark_properties(properties: dict, context) -> dict:
    """No-op fallback used when the OpenLineage listener cannot be imported.

    Logs a warning and returns the Spark properties unchanged.

    :param properties: Existing Spark properties mapping.
    :param context: Airflow task context (unused in this fallback).
    :return: The ``properties`` mapping, unmodified.
    """
    # Fix: the two concatenated string literals previously joined without a
    # space, producing "...listener`.Skipping..." in the log output.
    log.warning(
        "Could not import `airflow.providers.openlineage.plugins.listener`. "
        "Skipping the injection of OpenLineage transport information into Spark properties."
    )
    return properties

else:

def inject_transport_information_into_spark_properties(properties: dict, context) -> dict:
if any(str(key).startswith("spark.openlineage.transport") for key in properties):
log.info(
"Some OpenLineage properties with transport information are already present "
"in Spark properties. Skipping the injection of OpenLineage "
"transport information into Spark properties."
)
return properties

transport = get_openlineage_listener().adapter.get_or_create_openlineage_client().transport
if transport.kind != "http":
log.info(
"OpenLineage transport type `%s` does not support automatic "
"injection of OpenLineage transport information into Spark properties.",
transport.kind,
)
return {}

transport_properties = {
"spark.openlineage.transport.type": "http",
"spark.openlineage.transport.url": transport.url,
"spark.openlineage.transport.endpoint": transport.endpoint,
# Timeout is converted to milliseconds, as required by Spark integration,
"spark.openlineage.transport.timeoutInMillis": str(int(transport.timeout * 1000)),
}
if transport.compression:
transport_properties["spark.openlineage.transport.compression"] = str(
transport.compression
)

if hasattr(transport.config.auth, "api_key") and transport.config.auth.get_bearer():
transport_properties["spark.openlineage.transport.auth.type"] = "api_key"
transport_properties["spark.openlineage.transport.auth.apiKey"] = (
transport.config.auth.get_bearer()
)

if hasattr(transport.config, "custom_headers") and transport.config.custom_headers:
for key, value in transport.config.custom_headers.items():
transport_properties[f"spark.openlineage.transport.headers.{key}"] = value
def inject_parent_job_information_into_spark_properties(properties: dict, context) -> dict:
    """No-op fallback used when the OpenLineage provider macros cannot be imported.

    Logs a warning and returns the Spark properties unchanged, so callers may
    invoke this unconditionally regardless of provider availability.

    :param properties: Existing Spark properties mapping.
    :param context: Airflow task context (unused in this fallback).
    :return: The ``properties`` mapping, unmodified.
    """
    # Fix: the two concatenated string literals previously joined without a
    # space, producing "...macros`.Skipping..." in the log output.
    log.warning(
        "Could not import `airflow.providers.openlineage.plugins.macros`. "
        "Skipping the injection of OpenLineage parent job information into Spark properties."
    )
    return properties

return {**properties, **transport_properties}
def inject_transport_information_into_spark_properties(properties: dict, context) -> dict:
    """No-op fallback used when the OpenLineage listener cannot be imported.

    Logs a warning and returns the Spark properties unchanged.

    :param properties: Existing Spark properties mapping.
    :param context: Airflow task context (unused in this fallback).
    :return: The ``properties`` mapping, unmodified.
    """
    # Fix: the two concatenated string literals previously joined without a
    # space, producing "...listener`.Skipping..." in the log output.
    log.warning(
        "Could not import `airflow.providers.openlineage.plugins.listener`. "
        "Skipping the injection of OpenLineage transport information into Spark properties."
    )
    return properties


__all__ = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
EXAMPLE_CONTEXT = {
"ti": MagicMock(
dag_id="dag_id",
dag_run=MagicMock(run_after=dt.datetime(2024, 11, 11), logical_date=dt.datetime(2024, 11, 11)),
task_id="task_id",
try_number=1,
map_index=1,
Expand Down Expand Up @@ -574,7 +575,8 @@ def test_replace_dataproc_job_properties_key_error():
def test_inject_openlineage_properties_into_dataproc_job_provider_not_accessible(mock_is_accessible):
mock_is_accessible.return_value = False
job = {"sparkJob": {"properties": {"existingProperty": "value"}}}
result = inject_openlineage_properties_into_dataproc_job(job, None, True, True)

result = inject_openlineage_properties_into_dataproc_job(job, EXAMPLE_CONTEXT, True, True)
assert result == job


Expand All @@ -586,7 +588,7 @@ def test_inject_openlineage_properties_into_dataproc_job_unsupported_job_type(
mock_is_accessible.return_value = True
mock_extract_job_type.return_value = None
job = {"unsupportedJob": {"properties": {"existingProperty": "value"}}}
result = inject_openlineage_properties_into_dataproc_job(job, None, True, True)
result = inject_openlineage_properties_into_dataproc_job(job, EXAMPLE_CONTEXT, True, True)
assert result == job


Expand All @@ -599,7 +601,9 @@ def test_inject_openlineage_properties_into_dataproc_job_no_injection(
mock_extract_job_type.return_value = "sparkJob"
inject_parent_job_info = False
job = {"sparkJob": {"properties": {"existingProperty": "value"}}}
result = inject_openlineage_properties_into_dataproc_job(job, None, inject_parent_job_info, False)
result = inject_openlineage_properties_into_dataproc_job(
job, EXAMPLE_CONTEXT, inject_parent_job_info, False
)
assert result == job


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1559,9 +1559,13 @@ def test_dataproc_operator_execute_async_done_before_defer(self, mock_submit_job
op.execute(context=self.mock_context)
assert not mock_defer.called

@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
@mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_openlineage_parent_job_info_injection(self, mock_hook, mock_ol_accessible):
def test_execute_openlineage_parent_job_info_injection(
self, mock_hook, mock_ol_accessible, mock_static_uuid
):
mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
job_config = {
"placement": {"cluster_name": CLUSTER_NAME},
"pyspark_job": {
Expand Down Expand Up @@ -1620,13 +1624,15 @@ def test_execute_openlineage_parent_job_info_injection(self, mock_hook, mock_ol_
metadata=METADATA,
)

@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
@mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
@mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_openlineage_http_transport_info_injection(
self, mock_hook, mock_ol_accessible, mock_ol_listener
self, mock_hook, mock_ol_accessible, mock_ol_listener, mock_static_uuid
):
mock_ol_accessible.return_value = True
mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
)
Expand Down Expand Up @@ -1673,11 +1679,15 @@ def test_execute_openlineage_http_transport_info_injection(
metadata=METADATA,
)

@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
@mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
@mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_openlineage_all_info_injection(self, mock_hook, mock_ol_accessible, mock_ol_listener):
def test_execute_openlineage_all_info_injection(
self, mock_hook, mock_ol_accessible, mock_ol_listener, mock_static_uuid
):
mock_ol_accessible.return_value = True
mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
)
Expand Down Expand Up @@ -2705,10 +2715,14 @@ def test_wait_for_operation_on_execute(self, mock_hook):
)
mock_op.return_value.result.assert_not_called()

@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
@mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_openlineage_parent_job_info_injection(self, mock_hook, mock_ol_accessible):
def test_execute_openlineage_parent_job_info_injection(
self, mock_hook, mock_ol_accessible, mock_static_uuid
):
mock_ol_accessible.return_value = True
mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
template = {
**WORKFLOW_TEMPLATE,
"jobs": [
Expand Down Expand Up @@ -2891,13 +2905,15 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_ol_not_acces
metadata=METADATA,
)

@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
@mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
@mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_openlineage_transport_info_injection(
self, mock_hook, mock_ol_accessible, mock_ol_listener
self, mock_hook, mock_ol_accessible, mock_ol_listener, mock_static_uuid
):
mock_ol_accessible.return_value = True
mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
)
Expand Down Expand Up @@ -2995,11 +3011,15 @@ def test_execute_openlineage_transport_info_injection(
metadata=METADATA,
)

@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
@mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
@mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_openlineage_all_info_injection(self, mock_hook, mock_ol_accessible, mock_ol_listener):
def test_execute_openlineage_all_info_injection(
self, mock_hook, mock_ol_accessible, mock_ol_listener, mock_static_uuid
):
mock_ol_accessible.return_value = True
mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
)
Expand Down Expand Up @@ -3467,11 +3487,15 @@ def test_execute_batch_already_exists_cancelled(self, mock_hook):
metadata=METADATA,
)

@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
@mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
@mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_openlineage_parent_job_info_injection(self, mock_hook, to_dict_mock, mock_ol_accessible):
def test_execute_openlineage_parent_job_info_injection(
self, mock_hook, to_dict_mock, mock_ol_accessible, mock_static_uuid
):
mock_ol_accessible.return_value = True
mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
expected_batch = {
**BATCH,
"runtime_config": {"properties": OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES},
Expand Down Expand Up @@ -3504,14 +3528,16 @@ def test_execute_openlineage_parent_job_info_injection(self, mock_hook, to_dict_
metadata=METADATA,
)

@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
@mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
@mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
@mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_openlineage_transport_info_injection(
self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener
self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener, mock_static_uuid
):
mock_ol_accessible.return_value = True
mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
)
Expand Down Expand Up @@ -3547,14 +3573,16 @@ def test_execute_openlineage_transport_info_injection(
metadata=METADATA,
)

@mock.patch("airflow.providers.openlineage.plugins.adapter.generate_static_uuid")
@mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener")
@mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible")
@mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_openlineage_all_info_injection(
self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener
self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener, mock_static_uuid
):
mock_ol_accessible.return_value = True
mock_static_uuid.return_value = "01931885-2800-7be7-aa8d-aaa15c337267"
mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
)
Expand Down Expand Up @@ -3603,8 +3631,15 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_already_pres
self, mock_hook, to_dict_mock, mock_ol_accessible
):
mock_ol_accessible.return_value = True
expected_labels = {
"airflow-dag-id": "test_dag",
"airflow-dag-display-name": "test_dag",
"airflow-task-id": "task-id",
}

batch = {
**BATCH,
"labels": expected_labels,
"runtime_config": {
"properties": {
"spark.openlineage.parentJobName": "dag_id.task_id",
Expand All @@ -3625,6 +3660,7 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_already_pres
timeout=TIMEOUT,
metadata=METADATA,
openlineage_inject_parent_job_info=True,
dag=DAG(dag_id="test_dag"),
)
mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
op.execute(context=EXAMPLE_CONTEXT)
Expand All @@ -3650,8 +3686,16 @@ def test_execute_openlineage_transport_info_injection_skipped_when_already_prese
mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
)

expected_labels = {
"airflow-dag-id": "test_dag",
"airflow-dag-display-name": "test_dag",
"airflow-task-id": "task-id",
}

batch = {
**BATCH,
"labels": expected_labels,
"runtime_config": {
"properties": {
"spark.openlineage.transport.type": "console",
Expand All @@ -3672,6 +3716,7 @@ def test_execute_openlineage_transport_info_injection_skipped_when_already_prese
timeout=TIMEOUT,
metadata=METADATA,
openlineage_inject_transport_info=True,
dag=DAG(dag_id="test_dag"),
)
mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED)
op.execute(context=EXAMPLE_CONTEXT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from airflow.providers.openlineage import conf
from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter
from airflow.providers.openlineage.utils.utils import get_job_name
from airflow.providers.openlineage.version_compat import AIRFLOW_V_3_0_PLUS

if TYPE_CHECKING:
from airflow.models import TaskInstance
Expand Down Expand Up @@ -58,15 +59,25 @@ def lineage_run_id(task_instance: TaskInstance):
For more information take a look at the guide:
:ref:`howto/macros:openlineage`
"""
if hasattr(task_instance, "logical_date"):
logical_date = task_instance.logical_date
if AIRFLOW_V_3_0_PLUS:
context = task_instance.get_template_context()
if hasattr(task_instance, "dag_run"):
dag_run = task_instance.dag_run
elif hasattr(context, "dag_run"):
dag_run = context["dag_run"]
if hasattr(dag_run, "logical_date") and dag_run.logical_date:
date = dag_run.logical_date
else:
date = dag_run.run_after
elif hasattr(task_instance, "logical_date"):
date = task_instance.logical_date
else:
logical_date = task_instance.execution_date
date = task_instance.execution_date
return OpenLineageAdapter.build_task_instance_run_id(
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
try_number=task_instance.try_number,
logical_date=logical_date,
logical_date=date,
map_index=task_instance.map_index,
)

Expand Down
Loading