1 change: 1 addition & 0 deletions generated/provider_dependencies.json
@@ -284,6 +284,7 @@
},
"apache.spark": {
"deps": [
"apache-airflow-providers-common-compat>=1.5.0",
"apache-airflow>=2.9.0",
"grpcio-status>=1.59.0",
"pyspark>=3.1.3"
15 changes: 8 additions & 7 deletions providers/apache/spark/README.rst
@@ -50,13 +50,14 @@ The package supports the following python versions: 3.9,3.10,3.11,3.12
Requirements
------------

================== ==================
PIP package Version required
================== ==================
``apache-airflow`` ``>=2.9.0``
``pyspark`` ``>=3.1.3``
``grpcio-status`` ``>=1.59.0``
================== ==================
========================================== ==================
PIP package Version required
========================================== ==================
``apache-airflow`` ``>=2.9.0``
``apache-airflow-providers-common-compat`` ``>=1.5.0``
``pyspark`` ``>=3.1.3``
``grpcio-status`` ``>=1.59.0``
========================================== ==================

Cross provider package dependencies
-----------------------------------
4 changes: 1 addition & 3 deletions providers/apache/spark/pyproject.toml
@@ -58,6 +58,7 @@ requires-python = "~=3.9"
# After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
dependencies = [
"apache-airflow>=2.9.0",
"apache-airflow-providers-common-compat>=1.5.0",
"pyspark>=3.1.3",
"grpcio-status>=1.59.0",
]
@@ -68,9 +69,6 @@ dependencies = [
"cncf.kubernetes" = [
"apache-airflow-providers-cncf-kubernetes>=7.4.0",
]
"common.compat" = [
"apache-airflow-providers-common-compat"
]

[dependency-groups]
dev = [
@@ -125,10 +125,12 @@ def get_provider_info():
"name": "pyspark",
}
],
"dependencies": ["apache-airflow>=2.9.0", "pyspark>=3.1.3", "grpcio-status>=1.59.0"],
"optional-dependencies": {
"cncf.kubernetes": ["apache-airflow-providers-cncf-kubernetes>=7.4.0"],
"common.compat": ["apache-airflow-providers-common-compat"],
},
"dependencies": [
"apache-airflow>=2.9.0",
"apache-airflow-providers-common-compat>=1.5.0",
"pyspark>=3.1.3",
"grpcio-status>=1.59.0",
],
"optional-dependencies": {"cncf.kubernetes": ["apache-airflow-providers-cncf-kubernetes>=7.4.0"]},
"devel-dependencies": [],
}
@@ -20,8 +20,13 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

from airflow.configuration import conf
from airflow.models import BaseOperator
from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
from airflow.providers.common.compat.openlineage.utils.spark import (
inject_parent_job_information_into_spark_properties,
inject_transport_information_into_spark_properties,
)
from airflow.settings import WEB_COLORS

if TYPE_CHECKING:
@@ -135,6 +140,12 @@ def __init__(
yarn_queue: str | None = None,
deploy_mode: str | None = None,
use_krb5ccache: bool = False,
openlineage_inject_parent_job_info: bool = conf.getboolean(
"openlineage", "spark_inject_parent_job_info", fallback=False
),
openlineage_inject_transport_info: bool = conf.getboolean(
"openlineage", "spark_inject_transport_info", fallback=False
),
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -169,9 +180,17 @@ def __init__(
self._hook: SparkSubmitHook | None = None
self._conn_id = conn_id
self._use_krb5ccache = use_krb5ccache
self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info
self._openlineage_inject_transport_info = openlineage_inject_transport_info

def execute(self, context: Context) -> None:
"""Call the SparkSubmitHook to run the provided spark job."""
if self._openlineage_inject_parent_job_info:
self.log.debug("Injecting OpenLineage parent job information into Spark properties.")
self.conf = inject_parent_job_information_into_spark_properties(self.conf, context)
if self._openlineage_inject_transport_info:
self.log.debug("Injecting OpenLineage transport information into Spark properties.")
self.conf = inject_transport_information_into_spark_properties(self.conf, context)
if self._hook is None:
self._hook = self._get_hook()
self._hook.submit(self.application)
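
For orientation, a minimal usage sketch of the two new operator arguments introduced above. The DAG id, application path, and connection id are placeholders, not values from this change; when the arguments are omitted, the defaults come from the [openlineage] options spark_inject_parent_job_info and spark_inject_transport_info read via conf.getboolean, falling back to False.

from airflow import DAG
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

with DAG(dag_id="example_spark_openlineage") as dag:  # placeholder DAG id
    SparkSubmitOperator(
        task_id="spark_submit_job",
        application="/path/to/app.py",  # placeholder application path
        conn_id="spark_default",
        # Explicitly override the [openlineage] config defaults:
        openlineage_inject_parent_job_info=True,
        openlineage_inject_transport_info=True,
    )
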
@@ -17,7 +17,10 @@
# under the License.
from __future__ import annotations

import logging
from datetime import timedelta
from unittest import mock
from unittest.mock import MagicMock

import pytest

@@ -281,3 +284,179 @@ def test_templating_with_create_task_instance_of_operator(
assert task.application_args == "application_args"
assert task.env_vars == "env_vars"
assert task.properties_file == "properties_file"

@mock.patch.object(SparkSubmitOperator, "_get_hook")
@mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
def test_inject_simple_openlineage_config_to_spark(self, mock_get_openlineage_listener, mock_get_hook):
# Given / When
from openlineage.client.transport.http import (
ApiKeyTokenProvider,
HttpCompression,
HttpConfig,
HttpTransport,
)

mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
config=HttpConfig(
url="http://localhost:5000",
endpoint="api/v2/lineage",
timeout=5050,
auth=ApiKeyTokenProvider({"api_key": "12345"}),
compression=HttpCompression.GZIP,
custom_headers={"X-OpenLineage-Custom-Header": "airflow"},
)
)
operator = SparkSubmitOperator(
task_id="spark_submit_job",
spark_binary="sparky",
dag=self.dag,
openlineage_inject_parent_job_info=False,
openlineage_inject_transport_info=True,
**self._config,
)
operator.execute(MagicMock())

assert operator.conf == {
"parquet.compression": "SNAPPY",
"spark.openlineage.transport.type": "http",
"spark.openlineage.transport.url": "http://localhost:5000",
"spark.openlineage.transport.endpoint": "api/v2/lineage",
"spark.openlineage.transport.timeoutInMillis": "5050000",
"spark.openlineage.transport.compression": "gzip",
"spark.openlineage.transport.auth.type": "api_key",
"spark.openlineage.transport.auth.apiKey": "Bearer 12345",
"spark.openlineage.transport.headers.X-OpenLineage-Custom-Header": "airflow",
}

@mock.patch.object(SparkSubmitOperator, "_get_hook")
@mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
def test_inject_composite_openlineage_config_to_spark(self, mock_get_openlineage_listener, mock_get_hook):
# Given / When
from openlineage.client.transport.composite import CompositeConfig, CompositeTransport

mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = CompositeTransport(
CompositeConfig.from_dict(
{
"transports": {
"test1": {
"type": "http",
"url": "http://localhost:5000",
"endpoint": "api/v2/lineage",
"timeout": "5050",
"auth": {
"type": "api_key",
"api_key": "12345",
},
"compression": "gzip",
"custom_headers": {
"X-OpenLineage-Custom-Header": "airflow",
},
},
"test2": {
"type": "http",
"url": "https://example.com:1234",
},
"test3": {"type": "console"},
}
}
)
)

mock_ti = MagicMock()
mock_ti.dag_id = "test_dag_id"
mock_ti.task_id = "spark_submit_job"
mock_ti.try_number = 1
mock_ti.dag_run.logical_date = DEFAULT_DATE
mock_ti.dag_run.run_after = DEFAULT_DATE
mock_ti.logical_date = DEFAULT_DATE
mock_ti.map_index = -1

operator = SparkSubmitOperator(
task_id="spark_submit_job",
spark_binary="sparky",
dag=self.dag,
openlineage_inject_parent_job_info=True,
openlineage_inject_transport_info=True,
**self._config,
)
operator.execute({"ti": mock_ti})

assert operator.conf == {
"parquet.compression": "SNAPPY",
"spark.openlineage.parentJobName": "test_dag_id.spark_submit_job",
"spark.openlineage.parentJobNamespace": "default",
"spark.openlineage.parentRunId": "01595753-6400-710b-8a12-9e978335a56d",
"spark.openlineage.transport.type": "composite",
"spark.openlineage.transport.continueOnFailure": "True",
"spark.openlineage.transport.transports.test1.type": "http",
"spark.openlineage.transport.transports.test1.url": "http://localhost:5000",
"spark.openlineage.transport.transports.test1.endpoint": "api/v2/lineage",
"spark.openlineage.transport.transports.test1.timeoutInMillis": "5050000",
"spark.openlineage.transport.transports.test1.auth.type": "api_key",
"spark.openlineage.transport.transports.test1.auth.apiKey": "Bearer 12345",
"spark.openlineage.transport.transports.test1.compression": "gzip",
"spark.openlineage.transport.transports.test1.headers.X-OpenLineage-Custom-Header": "airflow",
"spark.openlineage.transport.transports.test2.type": "http",
"spark.openlineage.transport.transports.test2.url": "https://example.com:1234",
"spark.openlineage.transport.transports.test2.endpoint": "api/v1/lineage",
"spark.openlineage.transport.transports.test2.timeoutInMillis": "5000",
}

@mock.patch.object(SparkSubmitOperator, "_get_hook")
@mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
def test_inject_openlineage_composite_config_wrong_transport_to_spark(
self, mock_get_openlineage_listener, mock_get_hook, caplog
):
# Given / When
from openlineage.client.transport.composite import CompositeConfig, CompositeTransport

mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = CompositeTransport(
CompositeConfig.from_dict({"transports": {"test1": {"type": "console"}}})
)

with caplog.at_level(logging.INFO):
operator = SparkSubmitOperator(
task_id="spark_submit_job",
spark_binary="sparky",
dag=self.dag,
openlineage_inject_parent_job_info=False,
openlineage_inject_transport_info=True,
**self._config,
)
operator.execute(MagicMock())

assert (
"OpenLineage transport type `composite` does not contain http transport. Skipping injection of OpenLineage transport information into Spark properties."
in caplog.text
)
assert operator.conf == {
"parquet.compression": "SNAPPY",
}

@mock.patch.object(SparkSubmitOperator, "_get_hook")
@mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
def test_inject_openlineage_simple_config_wrong_transport_to_spark(
self, mock_get_openlineage_listener, mock_get_hook, caplog
):
# Given / When
from openlineage.client.transport.console import ConsoleConfig, ConsoleTransport

mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = ConsoleTransport(
config=ConsoleConfig()
)

with caplog.at_level(logging.INFO):
operator = SparkSubmitOperator(
task_id="spark_submit_job",
spark_binary="sparky",
dag=self.dag,
openlineage_inject_parent_job_info=False,
openlineage_inject_transport_info=True,
**self._config,
)
operator.execute(MagicMock())

assert "OpenLineage transport type `console` does not support automatic injection of OpenLineage transport information into Spark properties."
assert operator.conf == {
"parquet.compression": "SNAPPY",
}
@@ -60,7 +60,7 @@ def _get_transport_information(tp) -> dict:
"url": tp.url,
"endpoint": tp.endpoint,
"timeoutInMillis": str(
int(tp.timeout * 1000) # convert to milliseconds, as required by Spark integration
int(tp.timeout) * 1000 # convert to milliseconds, as required by Spark integration
),
}
if hasattr(tp, "compression") and tp.compression:
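
The one-character change in the hunk above matters when the transport was built from a dict-based config, as in the composite test earlier where timeout is the string "5050". A minimal sketch of the difference, assuming tp.timeout may be either a string or a number:

# Illustration only: the timeout values below are assumed inputs, not taken from the provider code.
for timeout in ("5050", 5050):
    print(int(timeout) * 1000)  # coerce first, then scale to milliseconds -> 5050000
# The previous expression, int(timeout * 1000), only behaves for numbers: for the
# string "5050" it would repeat the string 1000 times and parse that into an
# absurdly large number instead of 5050000.
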