11 changes: 6 additions & 5 deletions airflow/providers/apache/spark/hooks/spark_submit.py
@@ -33,7 +33,7 @@
with contextlib.suppress(ImportError, NameError):
from airflow.kubernetes import kube_client

-ALLOWED_SPARK_BINARIES = ["spark-submit", "spark2-submit"]
+ALLOWED_SPARK_BINARIES = ["spark-submit", "spark2-submit", "spark3-submit"]


class SparkSubmitHook(BaseHook, LoggingMixin):
@@ -78,7 +78,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
supports yarn and k8s mode too.
:param verbose: Whether to pass the verbose flag to spark-submit process for debugging
:param spark_binary: The command to use for spark submit.
-Some distros may use spark2-submit.
+Some distros may use spark2-submit or spark3-submit.
"""

conn_name_attr = "conn_id"
@@ -206,15 +206,16 @@ def _resolve_connection(self) -> dict[str, Any]:
spark_binary = self._spark_binary or extra.get("spark-binary", "spark-submit")
if spark_binary not in ALLOWED_SPARK_BINARIES:
raise RuntimeError(
f"The `spark-binary` extra can be on of {ALLOWED_SPARK_BINARIES} and it"
f"The `spark-binary` extra can be one of {ALLOWED_SPARK_BINARIES} and it"
f" was `{spark_binary}`. Please make sure your spark binary is one of the"
" allowed ones and that it is available on the PATH"
)
conn_spark_home = extra.get("spark-home")
if conn_spark_home:
raise RuntimeError(
"The `spark-home` extra is not allowed any more. Please make sure your `spark-submit` or"
" `spark2-submit` are available on the PATH."
"The `spark-home` extra is not allowed any more. Please make sure one of"
f" {ALLOWED_SPARK_BINARIES} is available on the PATH, and set `spark-binary`"
" if needed."
)
conn_data["spark_binary"] = spark_binary
conn_data["namespace"] = extra.get("namespace")
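For context, a minimal sketch of how the hook could be used with the newly allowed `spark3-submit` binary; the connection id and application path are placeholders, not part of this diff.

```python
from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook

# Assumed: a Spark connection named "spark_yarn" pointing at a YARN master already exists.
hook = SparkSubmitHook(
    conn_id="spark_yarn",
    spark_binary="spark3-submit",  # now accepted alongside spark-submit and spark2-submit
)
hook.submit(application="/path/to/app.py")  # placeholder application path
```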
2 changes: 1 addition & 1 deletion airflow/providers/apache/spark/operators/spark_submit.py
@@ -69,7 +69,7 @@ class SparkSubmitOperator(BaseOperator):
:param env_vars: Environment variables for spark-submit. It supports yarn and k8s mode too. (templated)
:param verbose: Whether to pass the verbose flag to spark-submit process for debugging
:param spark_binary: The command to use for spark submit.
-Some distros may use spark2-submit.
+Some distros may use spark2-submit or spark3-submit.
"""

template_fields: Sequence[str] = (
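Similarly, a hedged sketch of selecting `spark3-submit` at the operator level; the task id, connection id, and application path are illustrative only.

```python
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

# Assumed: defined inside an existing DAG; conn_id and application are placeholders.
submit_job = SparkSubmitOperator(
    task_id="submit_spark3_job",
    conn_id="spark_yarn",
    application="/path/to/app.py",
    spark_binary="spark3-submit",  # takes precedence over the connection's spark-binary extra
)
```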
@@ -42,7 +42,7 @@ Extra (optional)

* ``queue`` - The name of the YARN queue to which the application is submitted.
* ``deploy-mode`` - Whether to deploy your driver on the worker nodes (cluster) or locally as an external client (client).
-* ``spark-binary`` - The command to use for Spark submit. Some distros may use ``spark2-submit``. Default ``spark-submit``. Only ``spark-submit`` and ``spark2-submit`` are allowed as value.
+* ``spark-binary`` - The command to use for Spark submit. Some distros may use ``spark2-submit``. Default ``spark-submit``. Only ``spark-submit``, ``spark2-submit`` or ``spark3-submit`` are allowed as value.
* ``namespace`` - Kubernetes namespace (``spark.kubernetes.namespace``) to divide cluster resources between multiple users (via resource quota).

When specifying the connection in environment variable you should specify
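A hedged sketch of registering a connection whose `spark-binary` extra selects `spark3-submit`, mirroring the test fixtures added below; the connection id is a placeholder.

```python
from airflow.models import Connection
from airflow.utils import db

# Assumed: run in a context where the Airflow metadata database is initialised.
db.merge_conn(
    Connection(
        conn_id="spark_yarn_v3",
        conn_type="spark",
        host="yarn",
        extra='{"spark-binary": "spark3-submit"}',
    )
)
```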
33 changes: 30 additions & 3 deletions tests/providers/apache/spark/hooks/test_spark_submit.py
@@ -102,6 +102,14 @@ def setup_method(self):
extra='{"spark-binary": "spark2-submit"}',
)
)
db.merge_conn(
Connection(
conn_id="spark_binary_set_spark3_submit",
conn_type="spark",
host="yarn",
extra='{"spark-binary": "spark3-submit"}',
)
)
db.merge_conn(
Connection(
conn_id="spark_custom_binary_set",
@@ -434,6 +442,25 @@ def test_resolve_connection_spark_binary_set_connection(self):
assert connection == expected_spark_connection
assert cmd[0] == "spark2-submit"

def test_resolve_connection_spark_binary_spark3_submit_set_connection(self):
# Given
hook = SparkSubmitHook(conn_id="spark_binary_set_spark3_submit")

# When
connection = hook._resolve_connection()
cmd = hook._build_spark_submit_command(self._spark_job_file)

# Then
expected_spark_connection = {
"master": "yarn",
"spark_binary": "spark3-submit",
"deploy_mode": None,
"queue": None,
"namespace": None,
}
assert connection == expected_spark_connection
assert cmd[0] == "spark3-submit"

def test_resolve_connection_custom_spark_binary_not_allowed_runtime_error(self):
with pytest.raises(RuntimeError):
SparkSubmitHook(conn_id="spark_binary_set", spark_binary="another-custom-spark-submit")
@@ -448,7 +475,7 @@ def test_resolve_connection_spark_home_not_allowed_runtime_error(self):

def test_resolve_connection_spark_binary_default_value_override(self):
# Given
-hook = SparkSubmitHook(conn_id="spark_binary_set", spark_binary="spark2-submit")
+hook = SparkSubmitHook(conn_id="spark_binary_set", spark_binary="spark3-submit")

# When
connection = hook._resolve_connection()
@@ -457,13 +484,13 @@ def test_resolve_connection_spark_binary_default_value_override(self):
# Then
expected_spark_connection = {
"master": "yarn",
"spark_binary": "spark2-submit",
"spark_binary": "spark3-submit",
"deploy_mode": None,
"queue": None,
"namespace": None,
}
assert connection == expected_spark_connection
-assert cmd[0] == "spark2-submit"
+assert cmd[0] == "spark3-submit"

def test_resolve_connection_spark_binary_default_value(self):
# Given