@@ -796,11 +796,13 @@ def _refresh_cached_properties(self):
         del self.pod_manager
 
     def execute_async(self, context: Context) -> None:
-        self.pod_request_obj = self.build_pod_request_obj(context)
-        self.pod = self.get_or_create_pod(  # must set `self.pod` for `on_kill`
-            pod_request_obj=self.pod_request_obj,
-            context=context,
-        )
+        if self.pod_request_obj is None:
+            self.pod_request_obj = self.build_pod_request_obj(context)
+        if self.pod is None:
+            self.pod = self.get_or_create_pod(  # must set `self.pod` for `on_kill`
+                pod_request_obj=self.pod_request_obj,
+                context=context,
+            )
         if self.callbacks:
             pod = self.find_pod(self.pod.metadata.namespace, context=context)
             for callback in self.callbacks:
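Note: this guard makes `execute_async` idempotent, so it is safe to call when a subclass has already populated `self.pod` / `self.pod_request_obj` before delegating to the parent (as `SparkKubernetesOperator.execute` now does below). A minimal sketch of the "create only if unset" pattern — `build_pod_request_obj` and `get_or_create_pod` here are simplified placeholders, not the provider's real implementations:

```python
# Sketch of the lazy-initialization guard from the hunk above.
class PodOperatorSketch:
    def __init__(self, pod=None, pod_request_obj=None):
        self.pod = pod
        self.pod_request_obj = pod_request_obj

    def build_pod_request_obj(self, context):
        return {"metadata": {"name": "generic-pod"}}  # placeholder spec

    def get_or_create_pod(self, pod_request_obj, context):
        return pod_request_obj  # would call the Kubernetes API for real

    def execute_async(self, context):
        # A subclass may have pre-set these (e.g. a Spark driver pod);
        # only fall back to the generic path when they are still unset.
        if self.pod_request_obj is None:
            self.pod_request_obj = self.build_pod_request_obj(context)
        if self.pod is None:
            self.pod = self.get_or_create_pod(
                pod_request_obj=self.pod_request_obj, context=context
            )


# A pre-seeded pod survives; it is not overwritten by the generic builder.
op = PodOperatorSketch(pod={"metadata": {"name": "spark-driver"}})
op.execute_async(context={})
assert op.pod["metadata"]["name"] == "spark-driver"
```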
@@ -254,22 +254,23 @@ def find_spark_job(self, context, exclude_checked: bool = True):
         self.log.info("`try_number` of pod: %s", pod.metadata.labels["try_number"])
         return pod
 
-    def get_or_create_spark_crd(self, launcher: CustomObjectLauncher, context) -> k8s.V1Pod:
+    def get_or_create_spark_crd(self, context) -> k8s.V1Pod:
         if self.reattach_on_restart:
             driver_pod = self.find_spark_job(context)
             if driver_pod:
                 return driver_pod
 
-        driver_pod, spark_obj_spec = launcher.start_spark_job(
+        driver_pod, spark_obj_spec = self.launcher.start_spark_job(
             image=self.image, code_path=self.code_path, startup_timeout=self.startup_timeout_seconds
         )
         return driver_pod
 
     def process_pod_deletion(self, pod, *, reraise=True):
         if pod is not None:
             if self.delete_on_termination:
-                self.log.info("Deleting spark job: %s", pod.metadata.name.replace("-driver", ""))
-                self.launcher.delete_spark_job(pod.metadata.name.replace("-driver", ""))
+                pod_name = pod.metadata.name.replace("-driver", "")
+                self.log.info("Deleting spark job: %s", pod_name)
+                self.launcher.delete_spark_job(pod_name)
             else:
                 self.log.info("skipping deleting spark job: %s", pod.metadata.name)
 
@@ -293,18 +294,22 @@ def client(self) -> CoreV1Api:
     def custom_obj_api(self) -> CustomObjectsApi:
         return CustomObjectsApi()
 
-    def execute(self, context: Context):
-        self.name = self.create_job_name()
-
-        self.log.info("Creating sparkApplication.")
-        self.launcher = CustomObjectLauncher(
+    @cached_property
+    def launcher(self) -> CustomObjectLauncher:
+        launcher = CustomObjectLauncher(
             name=self.name,
             namespace=self.namespace,
             kube_client=self.client,
             custom_obj_api=self.custom_obj_api,
             template_body=self.template_body,
         )
-        self.pod = self.get_or_create_spark_crd(self.launcher, context)
+        return launcher
+
+    def execute(self, context: Context):
+        self.name = self.create_job_name()
+
+        self.log.info("Creating sparkApplication.")
+        self.pod = self.get_or_create_spark_crd(context)
         self.pod_request_obj = self.launcher.pod_spec
 
         return super().execute(context=context)
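Note: turning `launcher` into a `@cached_property` lets `get_or_create_spark_crd` and `process_pod_deletion` reference `self.launcher` without requiring `execute` to run first: the first access builds the object, later accesses reuse it, and (per the `_refresh_cached_properties` context in the first hunk) `del` drops the cached value so it is rebuilt lazily. A standalone sketch of that lifecycle — `Launcher` below is a stand-in, not the provider's `CustomObjectLauncher`:

```python
from functools import cached_property


class Launcher:
    pass


class OperatorSketch:
    @cached_property
    def launcher(self) -> Launcher:
        print("building launcher")  # executes only on first access
        return Launcher()


op = OperatorSketch()
first = op.launcher               # prints "building launcher", caches result
assert op.launcher is first      # cached: no rebuild on the second access

del op.launcher                  # invalidation, as _refresh_cached_properties does
assert op.launcher is not first  # rebuilt lazily on the next access
```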
@@ -23,7 +23,7 @@
 from datetime import date
 from functools import cached_property
 from unittest import mock
-from unittest.mock import patch
+from unittest.mock import mock_open, patch
 from uuid import uuid4
 
 import pendulum
@@ -32,9 +32,11 @@
 from kubernetes.client import models as k8s
 
 from airflow import DAG
+from airflow.exceptions import TaskDeferred
 from airflow.models import Connection, DagRun, TaskInstance
 from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator
 from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN
+from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
 from airflow.utils import timezone
 from airflow.utils.types import DagRunType
 
@@ -754,6 +756,41 @@ def test_find_custom_pod_labels(
         op.find_spark_job(context)
         mock_get_kube_client.list_namespaced_pod.assert_called_with("default", label_selector=label_selector)
 
+    @pytest.mark.asyncio
+    def test_execute_deferrable(
+        self,
+        mock_create_namespaced_crd,
+        mock_get_namespaced_custom_object_status,
+        mock_cleanup,
+        mock_create_job_name,
+        mock_get_kube_client,
+        mock_create_pod,
+        mock_await_pod_completion,
+        mock_fetch_requested_container_logs,
+        data_file,
+        mocker,
+    ):
+        task_name = "test_execute_deferrable"
+        job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())
+
+        mock_create_job_name.return_value = task_name
+        op = SparkKubernetesOperator(
+            template_spec=job_spec,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id=task_name,
+            get_logs=True,
+            deferrable=True,
+        )
+        context = create_context(op)
+
+        mock_file = mock_open(read_data='{"a": "b"}')
+        mocker.patch("builtins.open", mock_file)
+
+        with pytest.raises(TaskDeferred) as exc:
+            op.execute(context)
+
+        assert isinstance(exc.value.trigger, KubernetesPodTrigger)
+
+
 @pytest.mark.db_test
 def test_template_body_templating(create_task_instance_of_operator, session):
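Note: the new test leans on two generic patterns. `mock_open(read_data='{"a": "b"}')` patched over `builtins.open` appears to fake the kube-config file read triggered by `kubernetes_conn_id="kubernetes_default_kube_config"`, and the defer assertion works because a deferrable operator suspends by raising `TaskDeferred`, whose `trigger` attribute the test inspects. A self-contained sketch of that defer-and-assert pattern, using a toy operator and trigger (illustrative names, not the provider's API; requires `apache-airflow` and `pytest`):

```python
import pytest
from airflow.exceptions import TaskDeferred
from airflow.triggers.base import BaseTrigger, TriggerEvent


class ToyTrigger(BaseTrigger):
    """Illustrative trigger; real tests assert on e.g. KubernetesPodTrigger."""

    def serialize(self):
        return ("toy_module.ToyTrigger", {})

    async def run(self):
        yield TriggerEvent({"status": "done"})


class ToyDeferringOperator:
    def execute(self, context):
        # Deferrable operators suspend by raising TaskDeferred, naming the
        # trigger to hand to the triggerer and the method to resume in.
        raise TaskDeferred(trigger=ToyTrigger(), method_name="execute_complete")


def test_defers_with_toy_trigger():
    with pytest.raises(TaskDeferred) as exc:
        ToyDeferringOperator().execute(context={})
    assert isinstance(exc.value.trigger, ToyTrigger)
```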