71 changes: 60 additions & 11 deletions airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -21,6 +21,7 @@
import warnings
from typing import TYPE_CHECKING, Any, Generator

import aiofiles
from asgiref.sync import sync_to_async
from kubernetes import client, config, watch
from kubernetes.client.models import V1Pod
@@ -447,18 +448,54 @@ def _get_bool(val) -> bool | None:


class AsyncKubernetesHook(KubernetesHook):
"""Hook to use Kubernetes SDK asynchronously."""
"""
Creates Async Kubernetes API connection.

- use in cluster configuration by using extra field ``in_cluster`` in connection
- use custom config by providing path to the file using extra field ``kube_config_path`` in connection
- use custom configuration by providing content of kubeconfig file via
extra field ``kube_config`` in connection
- use default config by providing no extras

This hook checks for configuration options in the above order. The first option found is
used as the configuration.

.. seealso::
For more information about Kubernetes connection:
:doc:`/connections/kubernetes`

:param conn_id: The :ref:`kubernetes connection <howto/connection:kubernetes>`
to Kubernetes cluster.
:param client_configuration: Optional dictionary of client configuration params.
Passed on to kubernetes client.
:param cluster_context: Optionally specify a context to use (e.g. if you have multiple
contexts in your kubeconfig).
:param config_file: Path to kubeconfig file.
:param in_cluster: Set to ``True`` if running from within a kubernetes cluster.
:param disable_verify_ssl: Set to ``True`` if SSL verification should be disabled.
:param disable_tcp_keepalive: Set to ``True`` if you want to disable keepalive logic.
"""

def __init__(self, config_dict: dict | None = None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config_dict = config_dict

self._extras: dict | None = None

async def _load_config(self):
"""Returns Kubernetes API session for use with requests"""
async def _load_config(self) -> async_client.ApiClient:
"""
Load config to interact with Kubernetes

cluster_context: Optional[str] = None,
config_file: Optional[str] = None,
in_cluster: Optional[bool] = None,

"""
in_cluster = self._coalesce_param(self.in_cluster, await self._get_field("in_cluster"))
cluster_context = self._coalesce_param(self.cluster_context, await self._get_field("cluster_context"))
kubeconfig_path = self._coalesce_param(
self.config_file, await self._get_field("kube_config_path") or None
)
kubeconfig = await self._get_field("kube_config")

num_selected_configuration = len([o for o in [in_cluster, kubeconfig, self.config_dict] if o])
@@ -481,15 +518,20 @@ async def _load_config(self):
await async_config.load_kube_config_from_dict(self.config_dict)
return async_client.ApiClient()

if kubeconfig_path is not None:
self.log.debug("loading kube_config from: %s", kubeconfig_path)
await async_config.load_kube_config(
config_file=kubeconfig_path,
client_configuration=self.client_configuration,
context=cluster_context,
)
return async_client.ApiClient()

if kubeconfig is not None:
with tempfile.NamedTemporaryFile() as temp_config:
self.log.debug(
"Reading kubernetes configuration file from connection "
"object and writing temporary config file with its content",
)
temp_config.write(kubeconfig.encode())
temp_config.flush()
self._is_in_cluster = False
async with aiofiles.tempfile.NamedTemporaryFile() as temp_config:
self.log.debug("loading kube_config from: %s", kubeconfig_path)
await temp_config.write(kubeconfig.encode())
await temp_config.flush()
await async_config.load_kube_config(
config_file=temp_config.name,
client_configuration=self.client_configuration,
@@ -591,3 +633,10 @@ async def read_logs(self, name: str, namespace: str):
except HTTPError:
self.log.exception("There was an error reading the kubernetes API.")
raise

async def get_api_client_async(self) -> async_client.ApiClient:
"""Create an API Client object to interact with Kubernetes"""
kube_client = await self._load_config()
if kube_client is not None:
return kube_client
return async_client.ApiClient()
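Taken together, the hook resolves configuration in a fixed order: ``in_cluster``, then ``kube_config_path``, then inline ``kube_config`` content, then the default kubeconfig. A minimal usage sketch of the new ``get_api_client_async`` entry point, assuming a configured ``kubernetes_default`` connection and a reachable cluster (the function name and namespace here are illustrative):

import asyncio

from kubernetes_asyncio import client as async_client

from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook


async def list_pod_names(namespace: str = "default") -> list[str]:
    # Config resolution order: in_cluster -> kube_config_path -> kube_config -> default.
    hook = AsyncKubernetesHook(conn_id="kubernetes_default")
    api_client = await hook.get_api_client_async()
    core_v1 = async_client.CoreV1Api(api_client)
    pods = await core_v1.list_namespaced_pod(namespace=namespace)
    return [pod.metadata.name for pod in pods.items]


if __name__ == "__main__":
    # Requires a reachable cluster and valid credentials.
    print(asyncio.run(list_pod_names()))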
177 changes: 115 additions & 62 deletions airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -29,11 +29,12 @@
from typing import TYPE_CHECKING, Any, Sequence

from kubernetes.client import CoreV1Api, models as k8s
from pendulum import DateTime
from slugify import slugify
from urllib3.exceptions import HTTPError

from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.kubernetes import pod_generator
from airflow.kubernetes.pod_generator import PodGenerator
from airflow.kubernetes.secret import Secret
@@ -50,7 +51,10 @@
convert_volume_mount,
)
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from airflow.providers.cncf.kubernetes.triggers.kubernetes_pod import KubernetesPodTrigger
from airflow.providers.cncf.kubernetes.triggers.wait_container import (
PodLaunchTimeoutException,
WaitContainerTrigger,
)
from airflow.providers.cncf.kubernetes.utils import xcom_sidecar # type: ignore[attr-defined]
from airflow.providers.cncf.kubernetes.utils.pod_manager import (
PodLaunchFailedException,
@@ -61,7 +65,6 @@
from airflow.settings import pod_mutation_hook
from airflow.utils import yaml
from airflow.utils.helpers import prune_dict, validate_key
from airflow.utils.timezone import utcnow
from airflow.version import version as airflow_version

if TYPE_CHECKING:
@@ -290,6 +293,7 @@ def __init__(
base_container_name: str | None = None,
deferrable: bool = False,
poll_interval: float = 2,
logging_interval: int | None = None,
**kwargs,
) -> None:
# TODO: remove in provider 6.0.0 release. This is a mitigate step to advise users to switch to the
@@ -359,6 +363,7 @@ def __init__(
self.base_container_name = base_container_name or self.BASE_CONTAINER_NAME
self.deferrable = deferrable
self.poll_interval = poll_interval
self.logging_interval = logging_interval
self.remote_pod: k8s.V1Pod | None = None

self._config_dict: dict | None = None
@@ -561,76 +566,120 @@ def execute_async(self, context: Context):
context=context,
)
self.convert_config_file_to_dict()
self.invoke_defer_method()
self.defer()

def convert_config_file_to_dict(self):
"""Converts passed config_file to dict format."""
config_file = self.config_file if self.config_file else os.environ.get(KUBE_CONFIG_ENV_VAR)
if config_file:
with open(config_file) as f:
self._config_dict = yaml.safe_load(f)
else:
self._config_dict = None

def invoke_defer_method(self):
"""Method to easily redefine triggers which are being used in child classes."""
trigger_start_time = utcnow()
self.defer(
trigger=KubernetesPodTrigger(
def defer(self, last_log_time: DateTime | None = None, **kwargs: Any) -> None:
"""Defers to ``WaitContainerTrigger`` optionally with last log time."""
if kwargs:
raise ValueError(
f"Received keyword arguments {list(kwargs.keys())} but "
f"they are not used in this implementation of `defer`."
)
super().defer(
trigger=WaitContainerTrigger(
kubernetes_conn_id=self.kubernetes_conn_id,
hook_params={
"cluster_context": self.cluster_context,
"config_file": self.config_file,
"in_cluster": self.in_cluster,
},
pod_name=self.pod.metadata.name,
container_name=self.BASE_CONTAINER_NAME,
pod_namespace=self.pod.metadata.namespace,
trigger_start_time=trigger_start_time,
kubernetes_conn_id=self.kubernetes_conn_id,
cluster_context=self.cluster_context,
config_dict=self._config_dict,
in_cluster=self.in_cluster,
pending_phase_timeout=self.startup_timeout_seconds,
poll_interval=self.poll_interval,
should_delete_pod=self.is_delete_operator_pod,
get_logs=self.get_logs,
startup_timeout=self.startup_timeout_seconds,
base_container_name=self.base_container_name,
logging_interval=self.logging_interval,
last_log_time=last_log_time,
),
method_name="execute_complete",
method_name=self.trigger_reentry.__name__,
)

def execute_complete(self, context: Context, event: dict, **kwargs):
pod = None
@staticmethod
def raise_for_trigger_status(event: dict[str, Any]) -> None:
"""Raise exception if pod is not in expected state."""
if event["status"] == "error":
error_type = event["error_type"]
description = event["description"]
if error_type == "PodLaunchTimeoutException":
raise PodLaunchTimeoutException(description)
else:
raise AirflowException(description)

def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
"""
Point of re-entry from trigger.

If ``logging_interval`` is None, then at this point the pod should be done and we'll just fetch
the logs and exit.

If ``logging_interval`` is not None, it could be that the pod is still running and we'll just
grab the latest logs and defer back to the trigger again.
"""
remote_pod = None
try:
pod = self.hook.get_pod(
event["name"],
event["namespace"],
self.pod_request_obj = self.build_pod_request_obj(context)
self.pod = self.find_pod(
namespace=self.namespace or self.pod_request_obj.metadata.namespace,
context=context,
)
# This mirrors the current logic of the cleanup method; if cleanup is
# reworked in the future, this must be updated accordingly.
remote_pod = pod
if event["status"] in ("error", "failed", "timeout"):
# fetch some logs when the pod has failed
if self.get_logs:
self.write_logs(pod)
raise AirflowException(event["message"])
elif event["status"] == "success":
ti = context["ti"]
ti.xcom_push(key="pod_name", value=pod.metadata.name)
ti.xcom_push(key="pod_namespace", value=pod.metadata.namespace)

# fetch some logs when the pod has completed successfully
if self.get_logs:
self.write_logs(pod)

if self.do_xcom_push:
xcom_sidecar_output = self.extract_xcom(pod=pod)
pod = self.pod_manager.await_pod_completion(pod)
# This mirrors the current logic of the cleanup method; if cleanup is
# reworked in the future, this must be updated accordingly.
remote_pod = pod
return xcom_sidecar_output
finally:
if pod is not None and remote_pod is not None:
self.post_complete_action(
pod=pod,
remote_pod=remote_pod,

# we try to find the pod before possibly raising so that on_kill will have the `pod` attr
self.raise_for_trigger_status(event)

if not self.pod:
raise PodNotFoundException("Could not find pod after resuming from deferral")

if self.get_logs:
last_log_time = event and event.get("last_log_time")
if last_log_time:
self.log.info("Resuming logs read from time %r", last_log_time)
pod_log_status = self.pod_manager.fetch_container_logs(
pod=self.pod,
container_name=self.BASE_CONTAINER_NAME,
follow=self.logging_interval is None,
since_time=last_log_time,
)
if pod_log_status.running:
self.log.info("Container still running; deferring again.")
self.defer(pod_log_status.last_log_time)

if self.do_xcom_push:
result = self.extract_xcom(pod=self.pod)
remote_pod = self.pod_manager.await_pod_completion(self.pod)
except TaskDeferred:
raise
except Exception:
self.cleanup(
pod=self.pod or self.pod_request_obj,
remote_pod=remote_pod,
)
raise
self.cleanup(
pod=self.pod or self.pod_request_obj,
remote_pod=remote_pod,
)
ti = context["ti"]
ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace)
if self.do_xcom_push:
return result

def convert_config_file_to_dict(self):
"""Converts passed config_file to dict format."""
config_file = self.config_file if self.config_file else os.environ.get(KUBE_CONFIG_ENV_VAR)
if config_file:
with open(config_file) as f:
self._config_dict = yaml.safe_load(f)
else:
self._config_dict = None

def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
"""Deprecated; replaced by trigger_reentry."""
warnings.warn(
"Method `execute_complete` is deprecated and replaced with method `trigger_reentry`.",
DeprecationWarning,
)
return self.trigger_reentry(context=context, event=event)

def write_logs(self, pod: k8s.V1Pod):
try:
@@ -873,3 +922,7 @@ def __exit__(self, exctype, excinst, exctb):
return True
else:
return True


class PodNotFoundException(AirflowException):
"""Expected pod does not exist in kube-api."""