@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

from datetime import datetime, timezone
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast
@@ -29,7 +30,7 @@
from airflow.providers.cncf.kubernetes.operators.custom_object_launcher import CustomObjectLauncher
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN, PodGenerator
-from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
+from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager, PodPhase
from airflow.providers.common.compat.sdk import AirflowException
from airflow.utils.helpers import prune_dict

@@ -235,15 +236,40 @@ def template_body(self):
return self.manage_template_specs()

def find_spark_job(self, context, exclude_checked: bool = True):
"""
Find an existing Spark driver pod for this task instance.

The pod is identified using Airflow task context labels. If multiple
driver pods match the same labels (which can occur if cleanup did not
run after an abrupt failure), a single pod is selected deterministically
for reattachment, preferring a Running driver pod when present.
"""
label_selector = (
self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
+ ",spark-role=driver"
)
pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector).items

pod = None
-        if len(pod_list) > 1:  # and self.reattach_on_restart:
-            raise AirflowException(f"More than one pod running with labels: {label_selector}")
+        if len(pod_list) > 1:
# When multiple pods match the same labels, select one deterministically,
# preferring a Running pod, then creation time, with name as a tie-breaker.
pod = max(
pod_list,
key=lambda p: (
p.status.phase == PodPhase.RUNNING,
p.metadata.creation_timestamp or datetime.min.replace(tzinfo=timezone.utc),
p.metadata.name or "",
),
)
self.log.warning(
"Found %d Spark driver pods matching labels %s; "
"selecting pod %s for reattachment based on status and creation time.",
len(pod_list),
label_selector,
pod.metadata.name,
)

if len(pod_list) == 1:
pod = pod_list[0]
self.log.info(
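
The following is a minimal standalone sketch of the selection rule introduced above, assuming simple stand-in objects rather than the real Kubernetes pod models; the names DriverPod and pick_driver_pod are illustrative and not part of the provider's API. It shows why max() with this tuple key prefers a Running pod first, then the newest creation timestamp, then the lexicographically greatest name:

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timezone

RUNNING = "Running"  # PodPhase.RUNNING is this literal string in the provider


@dataclass
class DriverPod:
    name: str
    phase: str
    creation_timestamp: datetime | None


def pick_driver_pod(pods: list[DriverPod]) -> DriverPod:
    # Tuples compare element by element:
    #   1. phase == RUNNING is a bool, and True > False, so Running pods win;
    #   2. a newer creation_timestamp wins; an aware datetime.min stands in
    #      for a missing timestamp so aware and naive values never mix;
    #   3. the pod name breaks any remaining tie deterministically.
    return max(
        pods,
        key=lambda p: (
            p.phase == RUNNING,
            p.creation_timestamp or datetime.min.replace(tzinfo=timezone.utc),
            p.name,
        ),
    )


pods = [
    DriverPod("spark-driver-old", RUNNING, datetime(2025, 1, 1, tzinfo=timezone.utc)),
    DriverPod("spark-driver-pending", "Pending", datetime(2025, 1, 3, tzinfo=timezone.utc)),
    DriverPod("spark-driver-new", RUNNING, datetime(2025, 1, 2, tzinfo=timezone.utc)),
]
assert pick_driver_pod(pods).name == "spark-driver-new"
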
@@ -37,6 +37,7 @@
from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator
from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN
from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodPhase
from airflow.providers.common.compat.sdk import TaskDeferred
from airflow.utils import timezone
from airflow.utils.types import DagRunType
@@ -944,6 +945,170 @@ def test_reattach_on_restart_with_task_context_labels(

mock_create_namespaced_crd.assert_not_called()

def test_find_spark_job_picks_running_pod(
self,
mock_is_in_cluster,
mock_parent_execute,
mock_create_namespaced_crd,
mock_get_namespaced_custom_object_status,
mock_cleanup,
mock_create_job_name,
mock_get_kube_client,
mock_create_pod,
mock_await_pod_completion,
mock_fetch_requested_container_logs,
data_file,
):
"""
Verifies that find_spark_job picks a Running Spark driver pod over a non-Running pod.
"""

        task_name = "test_find_spark_job_picks_running_pod"
job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())

mock_create_job_name.return_value = task_name
op = SparkKubernetesOperator(
template_spec=job_spec,
kubernetes_conn_id="kubernetes_default_kube_config",
task_id=task_name,
get_logs=True,
reattach_on_restart=True,
)
context = create_context(op)

# Running pod should be selected.
running_pod = mock.MagicMock()
running_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
running_pod.metadata.name = "spark-driver-running"
running_pod.metadata.labels = {"try_number": "1"}
running_pod.status.phase = "Running"

# Pending pod should not be selected.
pending_pod = mock.MagicMock()
pending_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
pending_pod.metadata.name = "spark-driver-pending"
pending_pod.metadata.labels = {"try_number": "1"}
pending_pod.status.phase = "Pending"

mock_get_kube_client.list_namespaced_pod.return_value.items = [
running_pod,
pending_pod,
]

returned_pod = op.find_spark_job(context)

assert returned_pod is running_pod

def test_find_spark_job_picks_latest_pod(
self,
mock_is_in_cluster,
mock_parent_execute,
mock_create_namespaced_crd,
mock_get_namespaced_custom_object_status,
mock_cleanup,
mock_create_job_name,
mock_get_kube_client,
mock_create_pod,
mock_await_pod_completion,
mock_fetch_requested_container_logs,
data_file,
):
"""
Verifies that find_spark_job selects the most recently created Spark driver pod
when multiple candidate driver pods are present and status does not disambiguate.
"""

task_name = "test_find_spark_job_picks_latest_pod"
job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())

mock_create_job_name.return_value = task_name
op = SparkKubernetesOperator(
template_spec=job_spec,
kubernetes_conn_id="kubernetes_default_kube_config",
task_id=task_name,
get_logs=True,
reattach_on_restart=True,
)
context = create_context(op)

# Older pod that should be ignored.
old_mock_pod = mock.MagicMock()
old_mock_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
old_mock_pod.metadata.name = "spark-driver-old"
old_mock_pod.status.phase = PodPhase.RUNNING

# Newer pod that should be picked up.
new_mock_pod = mock.MagicMock()
new_mock_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 2, tzinfo=timezone.utc)
new_mock_pod.metadata.name = "spark-driver-new"
new_mock_pod.status.phase = PodPhase.RUNNING

# Same try_number to simulate abrupt failure scenarios (e.g. scheduler crash)
# where cleanup did not occur and multiple pods share identical labels.
old_mock_pod.metadata.labels = {"try_number": "1"}
new_mock_pod.metadata.labels = {"try_number": "1"}

mock_get_kube_client.list_namespaced_pod.return_value.items = [old_mock_pod, new_mock_pod]

returned_pod = op.find_spark_job(context)

assert returned_pod is new_mock_pod

def test_find_spark_job_tiebreaks_by_name(
self,
mock_is_in_cluster,
mock_parent_execute,
mock_create_namespaced_crd,
mock_get_namespaced_custom_object_status,
mock_cleanup,
mock_create_job_name,
mock_get_kube_client,
mock_create_pod,
mock_await_pod_completion,
mock_fetch_requested_container_logs,
data_file,
):
"""
Verifies that find_spark_job uses pod name as a deterministic tie-breaker
when multiple running Spark driver pods share the same creation_timestamp.
"""

task_name = "test_find_spark_job_tiebreaks_by_name"
job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())

mock_create_job_name.return_value = task_name
op = SparkKubernetesOperator(
template_spec=job_spec,
kubernetes_conn_id="kubernetes_default_kube_config",
task_id=task_name,
get_logs=True,
reattach_on_restart=True,
)
context = create_context(op)

# Use identical creation timestamps to force name-based tie-breaking.
ts = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)

# Pod with lexicographically smaller name should not be selected.
invalid_mock_pod = mock.MagicMock()
invalid_mock_pod.metadata.creation_timestamp = ts
invalid_mock_pod.metadata.name = "spark-driver-abc"
invalid_mock_pod.metadata.labels = {"try_number": "1"}
invalid_mock_pod.status.phase = PodPhase.RUNNING

# Pod with lexicographically greater name should be selected.
valid_mock_pod = mock.MagicMock()
valid_mock_pod.metadata.creation_timestamp = ts
valid_mock_pod.metadata.name = "spark-driver-xyz"
valid_mock_pod.metadata.labels = {"try_number": "1"}
valid_mock_pod.status.phase = PodPhase.RUNNING

mock_get_kube_client.list_namespaced_pod.return_value.items = [invalid_mock_pod, valid_mock_pod]

returned_pod = op.find_spark_job(context)

assert returned_pod is valid_mock_pod

@pytest.mark.asyncio
def test_execute_deferrable(
self,
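
For context, the lookup the tests exercise ultimately goes through CoreV1Api.list_namespaced_pod with the task-context label selector plus spark-role=driver appended. Below is a minimal sketch using the official kubernetes Python client directly, with illustrative label values (dag_id, task_id, try_number) rather than the exact labels the operator derives from the Airflow context:

from kubernetes import client, config

config.load_kube_config()  # inside a cluster, use config.load_incluster_config()
v1 = client.CoreV1Api()

# Illustrative selector; the operator builds the left-hand labels from the
# Airflow task context and appends spark-role=driver, as in the diff above.
label_selector = "dag_id=my_dag,task_id=my_task,try_number=1,spark-role=driver"
pods = v1.list_namespaced_pod("default", label_selector=label_selector).items
for pod in pods:
    print(pod.metadata.name, pod.status.phase, pod.metadata.creation_timestamp)

If this query returns more than one driver pod, the selection rule shown earlier picks one deterministically instead of raising, which is exactly the behavior the new tests verify.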