Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 80 additions & 23 deletions providers/src/airflow/providers/cncf/kubernetes/triggers/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any, AsyncIterator

from airflow.exceptions import AirflowException
from airflow.models.taskinstance import TaskInstance
from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook, KubernetesHook
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.session import provide_session
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from kubernetes.client import V1Job
Expand Down Expand Up @@ -98,6 +102,42 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"do_xcom_push": self.do_xcom_push,
},
)

@provide_session
def get_task_instance(self, session: Session) -> TaskInstance:
    """
    Fetch the current database row for the task instance that deferred this trigger.

    A fresh DB query is required because the in-memory ``self.task_instance``
    snapshot may be stale by the time the trigger is cancelled.

    :param session: SQLAlchemy session (injected by ``@provide_session``).
    :return: the matching ``TaskInstance`` row.
    :raises AirflowException: if no matching task instance is found.
    """
    query = session.query(TaskInstance).filter(
        TaskInstance.dag_id == self.task_instance.dag_id,
        TaskInstance.task_id == self.task_instance.task_id,
        TaskInstance.run_id == self.task_instance.run_id,
        TaskInstance.map_index == self.task_instance.map_index,
    )
    task_instance = query.one_or_none()
    if task_instance is None:
        # Build the message eagerly with an f-string: unlike logging calls,
        # AirflowException does not %-interpolate positional arguments, so the
        # original `("... %s ...", args...)` form produced an unformatted tuple.
        raise AirflowException(
            f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
            f"task_id: {self.task_instance.task_id}, "
            f"run_id: {self.task_instance.run_id} and "
            f"map_index: {self.task_instance.map_index} is not found"
        )
    return task_instance

def safe_to_cancel(self) -> bool:
    """
    Report whether it is safe to cancel the external job run by this trigger.

    An ``asyncio.CancelledError`` may also be raised because the trigger
    itself is being stopped (e.g. triggerer shutdown); in that situation the
    task instance is still DEFERRED and the external job must NOT be cancelled.
    """
    # Re-read the task instance from the database to observe its latest state;
    # the in-memory copy may be stale.
    latest_ti = self.get_task_instance()  # type: ignore[call-arg]
    still_deferred = latest_ti.state == TaskInstanceState.DEFERRED
    return not still_deferred


async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Get current job status and yield a TriggerEvent."""
Expand All @@ -116,39 +156,56 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
self.log.info("Extracting result from xcom sidecar container.")
loop = asyncio.get_running_loop()
xcom_result = await loop.run_in_executor(None, self.pod_manager.extract_xcom, pod)
job: V1Job = await self.hook.wait_until_job_complete(name=self.job_name, namespace=self.job_namespace)
job_dict = job.to_dict()
error_message = self.hook.is_job_failed(job=job)
yield TriggerEvent(
{
"name": job.metadata.name,
"namespace": job.metadata.namespace,
"pod_name": pod.metadata.name if self.get_logs else None,
"pod_namespace": pod.metadata.namespace if self.get_logs else None,
"status": "error" if error_message else "success",
"message": f"Job failed with error: {error_message}"
if error_message
else "Job completed successfully",
"job": job_dict,
"xcom_result": xcom_result if self.do_xcom_push else None,
}
)
try:
job: V1Job = await self.hook.wait_until_job_complete(name=self.job_name, namespace=self.job_namespace)
job_dict = job.to_dict()
error_message = self.hook.is_job_failed(job=job)

yield TriggerEvent(
{
"name": job.metadata.name,
"namespace": job.metadata.namespace,
"pod_name": pod.metadata.name if self.get_logs else None,
"pod_namespace": pod.metadata.namespace if self.get_logs else None,
"status": "error" if error_message else "success",
"message": f"Job failed with error: {error_message}"
if error_message
else "Job completed successfully",
"job": job_dict,
"xcom_result": xcom_result if self.do_xcom_push else None,
}
)
except asyncio.CancelledError:
try:
if self.safe_to_cancel():
self.log.info("Deleting job from kubernetes cluster")
resp = self.sync_hook.batch_v1_client.delete_namespaced_job(
name=self.job_name,
namespace=self.job_namespace,
)
self.log.info("Deleting job response: %s", resp)
except Exception as e:
self.log.error("Error during cancellation handling: %s", e)
raise AirflowException("Error during cancellation handling: %s", e)

@cached_property
def sync_hook(self) -> KubernetesHook:
    """
    Synchronous Kubernetes hook.

    Used where a blocking client is required, e.g. deleting the job during
    cancellation handling and building the pod manager's core client.
    """
    # NOTE(review): this span contained removed-side diff residue of the old
    # async `hook` property; reconstructed as the clean post-change property.
    return KubernetesHook(
        conn_id=self.kubernetes_conn_id,
        in_cluster=self.in_cluster,
        config_file=self.config_file,
        cluster_context=self.cluster_context,
    )

@cached_property
def hook(self) -> AsyncKubernetesHook:
    """Asynchronous Kubernetes hook used for the non-blocking wait on job completion."""
    # NOTE(review): this span contained removed-side diff residue of the old
    # `pod_manager` property; reconstructed as the two clean post-change properties.
    return AsyncKubernetesHook(
        conn_id=self.kubernetes_conn_id,
        in_cluster=self.in_cluster,
        config_file=self.config_file,
        cluster_context=self.cluster_context,
    )

@cached_property
def pod_manager(self) -> PodManager:
    """Pod manager bound to the synchronous client, used to extract xcom sidecar results."""
    return PodManager(kube_client=self.sync_hook.core_v1_client)