Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions airflow/providers/google/cloud/triggers/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence

from google.cloud.container_v1.types import Operation
from packaging.version import parse as parse_version

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction, PodManager
from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults
Expand All @@ -33,6 +34,7 @@
GKEKubernetesAsyncHook,
GKEKubernetesHook,
)
from airflow.providers_manager import ProvidersManager
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
Expand Down Expand Up @@ -305,19 +307,35 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
if self.get_logs or self.do_xcom_push:
pod = await self.hook.get_pod(name=self.pod_name, namespace=self.pod_namespace)
if self.do_xcom_push:
kubernetes_provider = ProvidersManager().providers["apache-airflow-providers-cncf-kubernetes"]
kubernetes_provider_name = kubernetes_provider.data["package-name"]
kubernetes_provider_version = kubernetes_provider.version
min_version = "8.4.1"
if parse_version(kubernetes_provider_version) < parse_version(min_version):
raise AirflowException(
"You are trying to use do_xcom_push in `GKEStartJobOperator` with the provider "
f"package {kubernetes_provider_name}=={kubernetes_provider_version} which doesn't "
f"support this feature. Please upgrade it to version higher than or equal to {min_version}."
)
await self.hook.wait_until_container_complete(
name=self.pod_name, namespace=self.pod_namespace, container_name=self.base_container_name
name=self.pod_name,
namespace=self.pod_namespace,
container_name=self.base_container_name,
poll_interval=self.poll_interval,
)
self.log.info("Checking if xcom sidecar container is started.")
await self.hook.wait_until_container_started(
name=self.pod_name,
namespace=self.pod_namespace,
container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
poll_interval=self.poll_interval,
)
self.log.info("Extracting result from xcom sidecar container.")
loop = asyncio.get_running_loop()
xcom_result = await loop.run_in_executor(None, self.pod_manager.extract_xcom, pod)
job: V1Job = await self.hook.wait_until_job_complete(name=self.job_name, namespace=self.job_namespace)
job: V1Job = await self.hook.wait_until_job_complete(
name=self.job_name, namespace=self.job_namespace, poll_interval=self.poll_interval
)
job_dict = job.to_dict()
error_message = self.hook.is_job_failed(job=job)
status = "error" if error_message else "success"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,9 @@ async def test_run_success(self, mock_hook, job_trigger):

event_actual = await job_trigger.run().asend(None)

mock_hook.wait_until_job_complete.assert_called_once_with(name=JOB_NAME, namespace=NAMESPACE)
mock_hook.wait_until_job_complete.assert_called_once_with(
name=JOB_NAME, namespace=NAMESPACE, poll_interval=POLL_INTERVAL
)
mock_job.to_dict.assert_called_once()
mock_is_job_failed.assert_called_once_with(job=mock_job)
assert event_actual == TriggerEvent(
Expand Down Expand Up @@ -544,7 +546,9 @@ async def test_run_fail(self, mock_hook, job_trigger):

event_actual = await job_trigger.run().asend(None)

mock_hook.wait_until_job_complete.assert_called_once_with(name=JOB_NAME, namespace=NAMESPACE)
mock_hook.wait_until_job_complete.assert_called_once_with(
name=JOB_NAME, namespace=NAMESPACE, poll_interval=POLL_INTERVAL
)
mock_job.to_dict.assert_called_once()
mock_is_job_failed.assert_called_once_with(job=mock_job)
assert event_actual == TriggerEvent(
Expand Down