Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions kubernetes_tests/test_kubernetes_pod_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,98 @@ class MyK8SPodOperator(KubernetesPodOperator):
)
assert MyK8SPodOperator(task_id=str(uuid4())).base_container_name == "tomato-sauce"

def test_init_container_logs(self, mock_get_connection):
    """Check that output of both the init container and the main container reaches the progress callback."""
    # Every container echoes its own unique marker so its output can be recognised afterwards.
    marker_from_init_container = f"{uuid4()}"
    marker_from_main_container = f"{uuid4()}"
    callback = MagicMock()

    operator = KubernetesPodOperator(
        namespace="default",
        image="busybox",
        cmds=["sh", "-cx"],
        arguments=[f"echo {marker_from_main_container}"],
        labels=self.labels,
        task_id=str(uuid4()),
        in_cluster=False,
        do_xcom_push=False,
        startup_timeout_seconds=60,
        init_containers=[
            k8s.V1Container(
                name="init-container",
                image="busybox",
                command=["sh", "-cx"],
                args=[f"echo {marker_from_init_container}"],
            ),
        ],
        # True requests the logs of every init container.
        init_container_logs=True,
        callbacks=callback,
    )
    operator.execute(create_context(operator))

    # Collect every ``line`` keyword passed to the progress callback into one searchable text.
    reported_lines = ("".join(call.kwargs["line"]) for call in callback.progress_callback.call_args_list)
    calls_args = "\n".join(reported_lines)
    assert marker_from_init_container in calls_args
    assert marker_from_main_container in calls_args

def test_init_container_logs_filtered(self, mock_get_connection):
    """
    Check that only the explicitly requested init containers are logged.

    Three init containers are declared but only two are requested, in an order that differs
    from their declaration order. The test expects the ignored container's marker to be
    absent and the remaining markers to appear in declaration order, before the main one.
    """
    marker_from_init_container_to_log_1 = f"{uuid4()}"
    marker_from_init_container_to_log_2 = f"{uuid4()}"
    marker_from_init_container_to_ignore = f"{uuid4()}"
    marker_from_main_container = f"{uuid4()}"
    callback = MagicMock()

    def echo_container(name, marker):
        # All init containers are identical apart from their name and the marker they echo.
        return k8s.V1Container(
            name=name,
            image="busybox",
            command=["sh", "-cx"],
            args=[f"echo {marker}"],
        )

    operator = KubernetesPodOperator(
        namespace="default",
        image="busybox",
        cmds=["sh", "-cx"],
        arguments=[f"echo {marker_from_main_container}"],
        labels=self.labels,
        task_id=str(uuid4()),
        in_cluster=False,
        do_xcom_push=False,
        startup_timeout_seconds=60,
        init_containers=[
            echo_container("init-container-to-log-1", marker_from_init_container_to_log_1),
            echo_container("init-container-to-log-2", marker_from_init_container_to_log_2),
            echo_container("init-container-to-ignore", marker_from_init_container_to_ignore),
        ],
        init_container_logs=[
            # deliberately listed in a different order than in init_containers
            "init-container-to-log-2",
            "init-container-to-log-1",
        ],
        callbacks=callback,
    )
    operator.execute(create_context(operator))

    # Collect every ``line`` keyword passed to the progress callback into one searchable text.
    reported_lines = ("".join(call.kwargs["line"]) for call in callback.progress_callback.call_args_list)
    calls_args = "\n".join(reported_lines)

    for expected_marker in (
        marker_from_init_container_to_log_1,
        marker_from_init_container_to_log_2,
    ):
        assert expected_marker in calls_args
    assert marker_from_init_container_to_ignore not in calls_args
    assert marker_from_main_container in calls_args

    # Output must follow the declaration order of the init containers, then the main container.
    position_log_1 = calls_args.find(marker_from_init_container_to_log_1)
    position_log_2 = calls_args.find(marker_from_init_container_to_log_2)
    position_main = calls_args.find(marker_from_main_container)
    assert position_log_1 < position_log_2 < position_main


def test_hide_sensitive_field_in_templated_fields_on_error(caplog, monkeypatch):
logger = logging.getLogger("airflow.task")
Expand Down
49 changes: 39 additions & 10 deletions providers/src/airflow/providers/cncf/kubernetes/operators/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ class KubernetesPodOperator(BaseOperator):
:param startup_timeout_seconds: timeout in seconds to startup the pod.
:param startup_check_interval_seconds: interval in seconds to check if the pod has already started
:param get_logs: get the stdout of the base container as logs of the tasks.
:param init_container_logs: list of init containers whose logs will be published to stdout
Takes a sequence of init container names, a single init container name or True. If True,
the logs of all init containers are published.
:param container_logs: list of containers whose logs will be published to stdout
Takes a sequence of containers, a single container name or True. If True,
all the containers logs are published. Works in conjunction with get_logs param.
Expand Down Expand Up @@ -278,6 +281,7 @@ def __init__(
startup_check_interval_seconds: int = 5,
get_logs: bool = True,
base_container_name: str | None = None,
init_container_logs: Iterable[str] | str | Literal[True] | None = None,
container_logs: Iterable[str] | str | Literal[True] | None = None,
image_pull_policy: str | None = None,
annotations: dict | None = None,
Expand Down Expand Up @@ -352,6 +356,7 @@ def __init__(
# Fallback to the class variable BASE_CONTAINER_NAME here instead of via default argument value
# in the init method signature, to be compatible with subclasses overloading the class variable value.
self.base_container_name = base_container_name or self.BASE_CONTAINER_NAME
self.init_container_logs = init_container_logs
self.container_logs = container_logs or self.base_container_name
self.image_pull_policy = image_pull_policy
self.node_selector = node_selector or {}
Expand Down Expand Up @@ -600,6 +605,9 @@ def execute_sync(self, context: Context):
self.callbacks.on_pod_creation(
pod=self.remote_pod, client=self.client, mode=ExecutionMode.SYNC
)

self.await_init_containers_completion(pod=self.pod)

self.await_pod_start(pod=self.pod)
if self.callbacks:
self.callbacks.on_pod_starting(
Expand Down Expand Up @@ -635,6 +643,22 @@ def execute_sync(self, context: Context):
if self.do_xcom_push:
return result

@tenacity.retry(
    wait=tenacity.wait_exponential(max=15),
    retry=tenacity.retry_if_exception_type(PodCredentialsExpiredFailure),
    reraise=True,
)
def await_init_containers_completion(self, pod: k8s.V1Pod):
    """
    Fetch the logs of the requested init containers, following them as they are produced.

    Does nothing when ``init_container_logs`` is falsy. A Kubernetes ``ApiException`` raised
    while fetching is delegated to ``_handle_api_exception``; the whole call is retried for as
    long as ``PodCredentialsExpiredFailure`` is raised.

    :param pod: the pod whose init containers should be logged.
    """
    requested = self.init_container_logs
    if not requested:
        # No init container logs were asked for.
        return

    try:
        self.pod_manager.fetch_requested_init_container_logs(
            pod=pod,
            init_containers=requested,
            follow_logs=True,
        )
    except kubernetes.client.exceptions.ApiException as exc:
        self._handle_api_exception(exc, pod)

@tenacity.retry(
wait=tenacity.wait_exponential(max=15),
retry=tenacity.retry_if_exception_type(PodCredentialsExpiredFailure),
Expand All @@ -653,16 +677,21 @@ def await_pod_completion(self, pod: k8s.V1Pod):
):
self.pod_manager.await_container_completion(pod=pod, container_name=self.base_container_name)
except kubernetes.client.exceptions.ApiException as exc:
if exc.status and str(exc.status) == "401":
self.log.warning(
"Failed to check container status due to permission error. Refreshing credentials and retrying."
)
self._refresh_cached_properties()
self.pod_manager.read_pod(
pod=pod
) # attempt using refreshed credentials, raises if still invalid
raise PodCredentialsExpiredFailure("Kubernetes credentials expired, retrying after refresh.")
raise exc
self._handle_api_exception(exc, pod)

def _handle_api_exception(
    self,
    exc: kubernetes.client.exceptions.ApiException,
    pod: k8s.V1Pod,
):
    """
    Translate a Kubernetes API error into a retryable credentials failure when applicable.

    :param exc: the API exception that was caught by the caller.
    :param pod: the pod that is read again after the cached properties are refreshed.
    :raises PodCredentialsExpiredFailure: when the status of ``exc`` is 401 and the pod could
        be read with the refreshed credentials.
    :raises kubernetes.client.exceptions.ApiException: ``exc`` itself for any other status.
    """
    is_unauthorized = exc.status and str(exc.status) == "401"
    if not is_unauthorized:
        # Not a credentials problem: hand the original error back to the caller.
        raise exc

    self.log.warning(
        "Failed to check container status due to permission error. Refreshing credentials and retrying."
    )
    self._refresh_cached_properties()
    # Attempt a read using the refreshed credentials; this raises if they are still invalid.
    self.pod_manager.read_pod(pod=pod)
    raise PodCredentialsExpiredFailure("Kubernetes credentials expired, retrying after refresh.")

def _refresh_cached_properties(self):
del self.hook
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import annotations

import enum
import itertools
import json
import math
import time
Expand Down Expand Up @@ -117,7 +118,13 @@ def get_xcom_sidecar_container_resources(self) -> str | None:

def get_container_status(pod: V1Pod, container_name: str) -> V1ContainerStatus | None:
"""Retrieve container status."""
container_statuses = pod.status.container_statuses if pod and pod.status else None
if pod and pod.status:
container_statuses = itertools.chain(
pod.status.container_statuses, pod.status.init_container_statuses
)
else:
container_statuses = None

if container_statuses:
# In general the variable container_statuses can store multiple items matching different containers.
# The following generator expression yields all items that have name equal to the container_name.
Expand Down Expand Up @@ -166,6 +173,19 @@ def container_is_succeeded(pod: V1Pod, container_name: str) -> bool:
return container_status.state.terminated.exit_code == 0


def container_is_wait(pod: V1Pod, container_name: str) -> bool:
    """
    Tell whether ``container_name`` of V1Pod ``pod`` is currently in the waiting state.

    Returns True only if a status for that container is found and its ``state.waiting``
    is set. Returns False otherwise, including when no status is found.
    """
    status = get_container_status(pod, container_name)
    return bool(status) and status.state.waiting is not None


def container_is_terminated(pod: V1Pod, container_name: str) -> bool:
"""
Examine V1Pod ``pod`` to determine whether ``container_name`` is terminated.
Expand Down Expand Up @@ -509,7 +529,7 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None
time.sleep(1)

def _reconcile_requested_log_containers(
self, requested: Iterable[str] | str | bool, actual: list[str], pod_name
self, requested: Iterable[str] | str | bool | None, actual: list[str], pod_name
) -> list[str]:
"""Return actual containers based on requested."""
containers_to_log = []
Expand Down Expand Up @@ -552,6 +572,31 @@ def _reconcile_requested_log_containers(
self.log.error("Could not retrieve containers for the pod: %s", pod_name)
return containers_to_log

def fetch_requested_init_container_logs(
    self, pod: V1Pod, init_containers: Iterable[str] | str | Literal[True] | None, follow_logs=False
) -> list[PodLoggingStatus]:
    """
    Fetch the logs of the requested init containers of the specified pod.

    The request is reconciled against the init containers declared in the pod spec. Each
    selected container is awaited and then has its logs fetched, one container at a time.

    :param pod: the pod whose init containers are inspected.
    :param init_containers: the init containers to log, as passed on to
        ``_reconcile_requested_log_containers``.
    :param follow_logs: forwarded to ``fetch_container_logs`` as ``follow``.
    :return: one logging status per processed init container, in processing order.

    :meta private:
    """
    declared_names = self.get_init_container_names(pod)
    selected_names = self._reconcile_requested_log_containers(
        requested=init_containers,
        actual=declared_names,
        pod_name=pod.metadata.name,
    )

    pod_logging_statuses = []
    # Init containers run sequentially, so process them in the order of spec.initContainers.
    for name in sorted(selected_names, key=declared_names.index):
        self._await_init_container_start(pod=pod, container_name=name)
        pod_logging_statuses.append(
            self.fetch_container_logs(pod=pod, container_name=name, follow=follow_logs)
        )
    return pod_logging_statuses

def fetch_requested_container_logs(
self, pod: V1Pod, containers: Iterable[str] | str | Literal[True], follow_logs=False
) -> list[PodLoggingStatus]:
Expand Down Expand Up @@ -679,9 +724,22 @@ def read_pod_logs(
post_termination_timeout=post_termination_timeout,
)

@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
def get_init_container_names(self, pod: V1Pod) -> list[str]:
    """
    Return the names of the init containers declared in the spec of the given pod.

    The names are returned in the order in which they appear in ``pod.spec.init_containers``.
    An empty list is returned when the pod declares no init containers.

    :param pod: the pod whose spec is inspected; no request is sent to the cluster.
    :return: the init container names, possibly empty.

    :meta private:
    """
    # ``init_containers`` may be None rather than an empty list when the pod has no init
    # containers; iterating None would raise TypeError (and trigger pointless retries).
    init_container_specs = pod.spec.init_containers or []
    return [container_spec.name for container_spec in init_container_specs]

@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
def get_container_names(self, pod: V1Pod) -> list[str]:
"""Return container names from the POD except for the airflow-xcom-sidecar container."""
"""
Return container names from the POD except for the airflow-xcom-sidecar container.

:meta private:
"""
pod_info = self.read_pod(pod)
return [
container_spec.name
Expand Down Expand Up @@ -819,6 +877,20 @@ def _exec_pod_command(self, resp, command: str) -> str | None:
return res
return None

def _await_init_container_start(self, pod: V1Pod, container_name: str):
    """
    Block until the given init container is reported by the pod and is no longer waiting.

    The pod is re-read on every iteration with a one second pause in between. There is no
    timeout: the loop ends only when the condition is met or when reading the pod raises.

    :param pod: the pod to poll.
    :param container_name: name of the init container to wait for.
    """

    def _has_started() -> bool:
        # Re-read the pod so that each check uses a fresh status.
        remote_pod = self.read_pod(pod)
        if remote_pod.status is None:
            return False
        if remote_pod.status.phase != PodPhase.PENDING:
            return get_container_status(
                remote_pod, container_name
            ) is not None and not container_is_wait(remote_pod, container_name)
        return False

    while not _has_started():
        time.sleep(1)


class OnFinishAction(str, enum.Enum):
"""Action to take when the pod finishes."""
Expand Down
4 changes: 4 additions & 0 deletions providers/tests/cncf/kubernetes/utils/test_pod_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,7 @@ def remote_pod(running=None, not_running=None):
e = RemotePodMock()
e.status = RemotePodMock()
e.status.container_statuses = []
e.status.init_container_statuses = []
for r in not_running or []:
e.status.container_statuses.append(container(r, False))
for r in running or []:
Expand All @@ -643,6 +644,7 @@ def container(name, running):
p = RemotePodMock()
p.status = RemotePodMock()
p.status.container_statuses = []
p.status.init_container_statuses = []
pod_mock_list.append(pytest.param(p, False, id="empty remote_pod.status.container_statuses"))
pod_mock_list.append(pytest.param(remote_pod(), False, id="filter empty"))
pod_mock_list.append(pytest.param(remote_pod(None, ["base"]), False, id="filter 0 running"))
Expand Down Expand Up @@ -858,6 +860,7 @@ def remote_pod(succeeded=None, not_succeeded=None):
e = RemotePodMock()
e.status = RemotePodMock()
e.status.container_statuses = []
e.status.init_container_statuses = []
for r in not_succeeded or []:
e.status.container_statuses.append(container(r, False))
for r in succeeded or []:
Expand All @@ -878,6 +881,7 @@ def container(name, succeeded):
p = RemotePodMock()
p.status = RemotePodMock()
p.status.container_statuses = []
p.status.init_container_statuses = []
pod_mock_list.append(pytest.param(p, False, id="empty remote_pod.status.container_statuses"))
pod_mock_list.append(pytest.param(remote_pod(), False, id="filter empty"))
pod_mock_list.append(pytest.param(remote_pod(None, ["base"]), False, id="filter 0 succeeded"))
Expand Down