Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,16 @@ def client(self) -> CoreV1Api:
def custom_obj_api(self) -> CustomObjectsApi:
    """Return a ``CustomObjectsApi`` instance built with its default arguments."""
    api = CustomObjectsApi()
    return api

@cached_property
def launcher(self) -> CustomObjectLauncher:
    """Build the ``CustomObjectLauncher`` used to manage this task's Spark job.

    The launcher is assembled from the operator's ``name``, ``namespace``,
    ``client``, ``custom_obj_api`` and ``template_body`` attributes. Because
    this is a ``cached_property``, those attributes are read on first access
    and the same launcher object is returned on every later access.
    """
    # Collected in the same order the constructor receives them so attribute
    # access order is unchanged.
    launcher_kwargs = {
        "name": self.name,
        "namespace": self.namespace,
        "kube_client": self.client,
        "custom_obj_api": self.custom_obj_api,
        "template_body": self.template_body,
    }
    return CustomObjectLauncher(**launcher_kwargs)

def get_or_create_spark_crd(self, launcher: CustomObjectLauncher, context) -> k8s.V1Pod:
if self.reattach_on_restart:
driver_pod = self.find_spark_job(context)
Expand Down Expand Up @@ -323,6 +333,8 @@ def _setup_spark_configuration(self, context: Context):
)
self.pod = existing_pod
self.pod_request_obj = None
if self.pod.metadata.name.endswith("-driver"):
self.name = self.pod.metadata.name.removesuffix("-driver")
return

if "spark" not in template_body:
Expand Down Expand Up @@ -361,9 +373,12 @@ def find_pod(self, namespace: str, context: Context, *, exclude_checked: bool =
return self.find_spark_job(context, exclude_checked=exclude_checked)

def on_kill(self) -> None:
    """Delete the Spark job associated with this task when the task is killed.

    The job name defaults to ``self.name``. When a pod is attached and its
    name ends with ``-driver``, the job name is instead derived from the pod
    name with that suffix removed, so the delete request targets the job the
    attached pod belongs to rather than a possibly stale ``self.name``.

    The launcher is always asked to delete exactly once, with the job name
    passed explicitly as ``spark_job_name``.
    """
    self.log.debug("Deleting spark job for task %s", self.task_id)

    job_name = self.name
    # ``self.pod`` may be unset (e.g. execute never ran), and a pod object may
    # lack metadata or a name, so guard every level before inspecting it.
    pod_name = self.pod.metadata.name if self.pod and self.pod.metadata else None
    if pod_name and pod_name.endswith("-driver"):
        job_name = pod_name.removesuffix("-driver")

    self.launcher.delete_spark_job(spark_job_name=job_name)

def patch_already_checked(self, pod: k8s.V1Pod, *, reraise=True):
"""Add an "already checked" annotation to ensure we don't reattach on retries."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1149,3 +1149,102 @@ def test_create_job_name_should_truncate_long_names():
pod_name = op.create_job_name()

assert pod_name == long_name[:MAX_LABEL_LEN]


class TestSparkKubernetesLifecycle:
    """Lifecycle tests for ``SparkKubernetesOperator.launcher`` and ``on_kill``.

    Every test patches ``CustomObjectLauncher`` and ``KubernetesHook`` in the
    operator module, and patches ``manage_template_specs`` because no real
    application file exists on disk.
    """

    @mock.patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.CustomObjectLauncher")
    @mock.patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
    def test_launcher_access_without_execute(self, mock_hook, mock_launcher_cls):
        """Test that launcher is accessible even if execute is not called (e.g. after deferral)."""
        op = SparkKubernetesOperator(
            task_id="test_task",
            namespace="default",
            application_file="example.yaml",
            kubernetes_conn_id="kubernetes_default",
        )

        # Mock the template body loading since we don't have a real file
        with mock.patch.object(
            SparkKubernetesOperator, "manage_template_specs", return_value={"spark": {"spec": {}}}
        ):
            # Access launcher
            launcher = op.launcher

            assert launcher is not None
            assert mock_launcher_cls.called

    @mock.patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.CustomObjectLauncher")
    @mock.patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
    def test_on_kill_works_without_execute(self, mock_hook, mock_launcher_cls):
        """Test that on_kill works without execute being called."""
        op = SparkKubernetesOperator(
            task_id="test_task",
            namespace="default",
            application_file="example.yaml",
            name="test-job",
        )

        mock_launcher_instance = mock_launcher_cls.return_value

        with mock.patch.object(
            SparkKubernetesOperator, "manage_template_specs", return_value={"spark": {"spec": {}}}
        ):
            op.on_kill()

            # Exactly one delete, addressed to the operator's own job name.
            mock_launcher_instance.delete_spark_job.assert_called_once_with(spark_job_name="test-job")

    @mock.patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.CustomObjectLauncher")
    @mock.patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
    def test_on_kill_uses_attached_driver_pod_name(self, mock_hook, mock_launcher_cls):
        """Test that on_kill derives the job name from an attached ``-driver`` pod."""
        op = SparkKubernetesOperator(
            task_id="test_task",
            namespace="default",
            application_file="example.yaml",
            name="test-job",
        )
        # Simulate a pod attached to the operator whose name carries the driver suffix.
        op.pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="reattached-job-driver"))

        with mock.patch.object(
            SparkKubernetesOperator, "manage_template_specs", return_value={"spark": {"spec": {}}}
        ):
            op.on_kill()

            # The suffix is stripped and the pod-derived name wins over ``name``.
            mock_launcher_cls.return_value.delete_spark_job.assert_called_once_with(
                spark_job_name="reattached-job"
            )

    @mock.patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.CustomObjectLauncher")
    @mock.patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
    @mock.patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesPodOperator.execute")
    def test_reattach_skips_launcher_creation_in_execute(
        self, mock_super_execute, mock_hook, mock_launcher_cls
    ):
        """Test that reattach logic skips explicit launcher creation but property still works."""
        op = SparkKubernetesOperator(
            task_id="test_task",
            namespace="default",
            application_file="example.yaml",
            reattach_on_restart=True,
        )

        # Mock finding an existing pod
        mock_pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="existing-pod"))

        with (
            mock.patch.object(SparkKubernetesOperator, "find_spark_job", return_value=mock_pod),
            mock.patch.object(
                SparkKubernetesOperator, "manage_template_specs", return_value={"spark": {"spec": {}}}
            ),
            mock.patch.object(SparkKubernetesOperator, "_get_ti_pod_labels", return_value={}),
        ):
            context = {"ti": mock.MagicMock(), "run_id": "test_run"}

            # Run execute
            op.execute(context)

            # Verify super().execute was called
            mock_super_execute.assert_called_once()

            # The reattach path is expected to return early from the Spark
            # configuration step without building a launcher. Reset the mock
            # so the assertion below counts only the instantiation triggered
            # by the explicit property access.
            mock_launcher_cls.reset_mock()

            # Access launcher now
            assert op.launcher is not None

            # Now it should have been instantiated
            mock_launcher_cls.assert_called_once()

            # And verify delete works
            op.on_kill()
            mock_launcher_cls.return_value.delete_spark_job.assert_called()