Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
b841b49
Create a generic callbacks class for KubernetesPodOperator
hussein-awala Nov 17, 2023
dde518d
Merge branch 'main' into kpo_callbacks
hussein-awala Nov 28, 2023
2797aef
Trigger tests with old-style union
hussein-awala Nov 28, 2023
6b884f0
Fix GCP K8S test
hussein-awala Nov 28, 2023
3ddfd82
Fix callback param type and cleanup calls
hussein-awala Nov 28, 2023
d245298
Some fixes and add unit tests
hussein-awala Nov 28, 2023
f722e40
Replace type by Type
hussein-awala Nov 28, 2023
c4ae081
Reset mock_callbacks in pod manager tests
hussein-awala Nov 28, 2023
c13b95e
Merge branch 'main' into kpo_callbacks
hussein-awala Jan 9, 2024
6173c1b
Fix static checks
hussein-awala Jan 9, 2024
b1755af
Add a doc paragraph for the new callbacks
hussein-awala Jan 10, 2024
1860881
Add a check for cncf-kuberntes version in google provider
hussein-awala Jan 10, 2024
479b3c8
Merge branch 'main' into kpo_callbacks
hussein-awala Jan 11, 2024
d9fd475
Switch to None default value and bump min cncf-k8s provider in google…
hussein-awala Jan 12, 2024
255b060
Merge branch 'main' into kpo_callbacks
hussein-awala Jan 12, 2024
39ece30
Fix tests
hussein-awala Jan 13, 2024
1d133d7
Reduce check intervals to avoid killing the asyncio task
hussein-awala Jan 13, 2024
be57a68
Merge branch 'main' into kpo_callbacks
hussein-awala Jan 19, 2024
79354e8
Revert async callbacks
hussein-awala Jan 20, 2024
f56a925
fix breeze tests
hussein-awala Jan 20, 2024
3470855
Merge branch 'main' into kpo_callbacks
hussein-awala Jan 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions airflow/providers/cncf/kubernetes/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from enum import Enum
from typing import Union

import kubernetes.client as k8s
import kubernetes_asyncio.client as async_k8s

# Type alias for the Kubernetes core API client handed to the callbacks: either the
# synchronous `kubernetes` client or the `kubernetes_asyncio` client.
client_type = Union[k8s.CoreV1Api, async_k8s.CoreV1Api]


class ExecutionMode(str, Enum):
    """Execution mode reported to the callbacks through their ``mode`` argument.

    The members are also plain strings, so they compare equal to their values
    (``ExecutionMode.SYNC == "sync"``).
    """

    # Value passed when a callback is invoked from the synchronous code path.
    SYNC = "sync"
    # Value reserved for the asynchronous code path.
    ASYNC = "async"


class KubernetesPodOperatorCallback:
    """Callback methods for the `KubernetesPodOperator`.

    Every method is a keyword-only static method that does nothing by default; override the
    ones of interest in a subclass. Each method accepts extra ``**kwargs`` so that new keyword
    arguments can be passed to it without breaking existing overrides.

    Currently, the callback methods are not called in the async mode; this support will be added
    in the future.
    """

    @staticmethod
    def on_sync_client_creation(*, client: k8s.CoreV1Api, **kwargs) -> None:
        """Callback method called after creating the sync client.

        :param client: the created `kubernetes.client.CoreV1Api` client.
        """

    @staticmethod
    def on_pod_creation(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
        """Callback method called after creating the pod.

        :param pod: the created pod.
        :param client: the Kubernetes client that can be used in the callback.
        :param mode: the current execution mode, it's one of (`sync`, `async`).
        """

    @staticmethod
    def on_pod_starting(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
        """Callback method called when the pod starts.

        :param pod: the started pod.
        :param client: the Kubernetes client that can be used in the callback.
        :param mode: the current execution mode, it's one of (`sync`, `async`).
        """

    @staticmethod
    def on_pod_completion(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
        """Callback method called when the pod completes.

        :param pod: the completed pod.
        :param client: the Kubernetes client that can be used in the callback.
        :param mode: the current execution mode, it's one of (`sync`, `async`).
        """

    @staticmethod
    def on_pod_cleanup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
        """Callback method called after cleaning/deleting the pod.

        :param pod: the completed pod.
        :param client: the Kubernetes client that can be used in the callback.
        :param mode: the current execution mode, it's one of (`sync`, `async`).
        """

    @staticmethod
    def on_operator_resuming(
        *, pod: k8s.V1Pod, event: dict, client: client_type, mode: str, **kwargs
    ) -> None:
        """Callback method called when resuming the `KubernetesPodOperator` from deferred state.

        :param pod: the current state of the pod.
        :param event: the returned event from the Trigger.
        :param client: the Kubernetes client that can be used in the callback.
        :param mode: the current execution mode, it's one of (`sync`, `async`).
        """

    @staticmethod
    def progress_callback(*, line: str, client: client_type, mode: str, **kwargs) -> None:
        """Callback method to process pod container logs.

        :param line: the read line of log.
        :param client: the Kubernetes client that can be used in the callback.
        :param mode: the current execution mode, it's one of (`sync`, `async`).
        """
48 changes: 45 additions & 3 deletions airflow/providers/cncf/kubernetes/operators/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
convert_volume,
convert_volume_mount,
)
from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode, KubernetesPodOperatorCallback
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
POD_NAME_MAX_LENGTH,
Expand Down Expand Up @@ -198,7 +199,10 @@ class KubernetesPodOperator(BaseOperator):
Default value is "File"
:param active_deadline_seconds: The active_deadline_seconds which translates to active_deadline_seconds
in V1PodSpec.
:param callbacks: KubernetesPodOperatorCallback class containing the callback methods called at the
different steps of KubernetesPodOperator.
:param progress_callback: Callback function for receiving k8s container logs.
`progress_callback` is deprecated, please use the `callbacks` parameter instead.
"""

# !!! Changes in KubernetesPodOperator's arguments should be also reflected in !!!
Expand Down Expand Up @@ -290,6 +294,7 @@ def __init__(
is_delete_operator_pod: None | bool = None,
termination_message_policy: str = "File",
active_deadline_seconds: int | None = None,
callbacks: type[KubernetesPodOperatorCallback] | None = None,
progress_callback: Callable[[str], None] | None = None,
**kwargs,
) -> None:
Expand Down Expand Up @@ -381,6 +386,7 @@ def __init__(

self._config_dict: dict | None = None # TODO: remove it when removing convert_config_file_to_dict
self._progress_callback = progress_callback
self.callbacks = callbacks
self._killed: bool = False

@cached_property
Expand Down Expand Up @@ -459,7 +465,9 @@ def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool

@cached_property
def pod_manager(self) -> PodManager:
return PodManager(kube_client=self.client, progress_callback=self._progress_callback)
return PodManager(
kube_client=self.client, callbacks=self.callbacks, progress_callback=self._progress_callback
)

@cached_property
def hook(self) -> PodOperatorHookProtocol:
Expand All @@ -473,7 +481,10 @@ def hook(self) -> PodOperatorHookProtocol:

@cached_property
def client(self) -> CoreV1Api:
return self.hook.core_v1_client
client = self.hook.core_v1_client
if self.callbacks:
self.callbacks.on_sync_client_creation(client=client)
return client

def find_pod(self, namespace: str, context: Context, *, exclude_checked: bool = True) -> k8s.V1Pod | None:
"""Return an already-running pod for this task instance if one exists."""
Expand Down Expand Up @@ -552,7 +563,17 @@ def execute_sync(self, context: Context):

# get remote pod for use in cleanup methods
self.remote_pod = self.find_pod(self.pod.metadata.namespace, context=context)
if self.callbacks:
self.callbacks.on_pod_creation(
pod=self.remote_pod, client=self.client, mode=ExecutionMode.SYNC
)
self.await_pod_start(pod=self.pod)
if self.callbacks:
self.callbacks.on_pod_starting(
pod=self.find_pod(self.pod.metadata.namespace, context=context),
client=self.client,
mode=ExecutionMode.SYNC,
)

if self.get_logs:
self.pod_manager.fetch_requested_container_logs(
Expand All @@ -566,6 +587,12 @@ def execute_sync(self, context: Context):
self.pod_manager.await_container_completion(
pod=self.pod, container_name=self.base_container_name
)
if self.callbacks:
self.callbacks.on_pod_completion(
pod=self.find_pod(self.pod.metadata.namespace, context=context),
client=self.client,
mode=ExecutionMode.SYNC,
)

if self.do_xcom_push:
self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod)
Expand All @@ -575,10 +602,13 @@ def execute_sync(self, context: Context):
self.pod, istio_enabled, self.base_container_name
)
finally:
pod_to_clean = self.pod or self.pod_request_obj
self.cleanup(
pod=self.pod or self.pod_request_obj,
pod=pod_to_clean,
remote_pod=self.remote_pod,
)
if self.callbacks:
self.callbacks.on_pod_cleanup(pod=pod_to_clean, client=self.client, mode=ExecutionMode.SYNC)

if self.do_xcom_push:
return result
Expand All @@ -589,6 +619,12 @@ def execute_async(self, context: Context):
pod_request_obj=self.pod_request_obj,
context=context,
)
if self.callbacks:
self.callbacks.on_pod_creation(
pod=self.find_pod(self.pod.metadata.namespace, context=context),
client=self.client,
mode=ExecutionMode.SYNC,
)
ti = context["ti"]
ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace)
Expand Down Expand Up @@ -625,6 +661,10 @@ def execute_complete(self, context: Context, event: dict, **kwargs):
event["name"],
event["namespace"],
)
if self.callbacks:
self.callbacks.on_operator_resuming(
pod=pod, event=event, client=self.client, mode=ExecutionMode.SYNC
)
if event["status"] in ("error", "failed", "timeout"):
# fetch some logs when pod is failed
if self.get_logs:
Expand Down Expand Up @@ -677,6 +717,8 @@ def post_complete_action(self, *, pod, remote_pod, **kwargs):
pod=pod,
remote_pod=remote_pod,
)
if self.callbacks:
self.callbacks.on_pod_cleanup(pod=pod, client=self.client, mode=ExecutionMode.SYNC)

def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod):
# If a task got marked as failed, "on_kill" method would be called and the pod will be cleaned up
Expand Down
22 changes: 18 additions & 4 deletions airflow/providers/cncf/kubernetes/utils/pod_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from urllib3.exceptions import HTTPError, TimeoutError

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode, KubernetesPodOperatorCallback
from airflow.providers.cncf.kubernetes.pod_generator import PodDefaults
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.timezone import utcnow
Expand All @@ -50,6 +51,7 @@
from kubernetes.client.models.v1_pod import V1Pod
from urllib3.response import HTTPResponse


EMPTY_XCOM_RESULT = "__airflow_xcom_result_empty__"
"""
Sentinel for no xcom result.
Expand Down Expand Up @@ -287,18 +289,22 @@ class PodManager(LoggingMixin):
def __init__(
self,
kube_client: client.CoreV1Api,
callbacks: type[KubernetesPodOperatorCallback] | None = None,
progress_callback: Callable[[str], None] | None = None,
):
"""
Create the launcher.

:param kube_client: kubernetes client
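:param callbacks: KubernetesPodOperatorCallback class whose `progress_callback` is invoked for each fetched container log line.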
:param progress_callback: Callback function invoked when fetching container log.
This parameter is deprecated, please use ``callbacks`` instead.
"""
super().__init__()
self._client = kube_client
self._progress_callback = progress_callback
self._watch = watch.Watch()
self._callbacks = callbacks

def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod:
"""Run POD asynchronously."""
Expand Down Expand Up @@ -441,9 +447,13 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None
message_timestamp = line_timestamp
progress_callback_lines.append(line)
else: # previous log line is complete
if self._progress_callback:
for line in progress_callback_lines:
for line in progress_callback_lines:
if self._progress_callback:
self._progress_callback(line)
if self._callbacks:
self._callbacks.progress_callback(
line=line, client=self._client, mode=ExecutionMode.SYNC
)
self.log.info("[%s] %s", container_name, message_to_log)
last_captured_timestamp = message_timestamp
message_to_log = message
Expand All @@ -454,9 +464,13 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None
progress_callback_lines.append(line)
finally:
# log the last line and update the last_captured_timestamp
if self._progress_callback:
for line in progress_callback_lines:
for line in progress_callback_lines:
if self._progress_callback:
self._progress_callback(line)
if self._callbacks:
self._callbacks.progress_callback(
line=line, client=self._client, mode=ExecutionMode.SYNC
)
self.log.info("[%s] %s", container_name, message_to_log)
last_captured_timestamp = message_timestamp
except TimeoutError as e:
Expand Down
Loading