@@ -16,15 +16,14 @@
# under the License.
from __future__ import annotations

import asyncio
import json
import logging
import os
import shutil
from contextlib import nullcontext
from copy import copy
from unittest import mock
from unittest.mock import ANY, MagicMock
from unittest.mock import ANY, AsyncMock, MagicMock
from uuid import uuid4

import pytest
@@ -187,7 +186,6 @@ def test_do_xcom_push_defaults_false(self, kubeconfig_path, mock_get_connection,
)
assert not k.do_xcom_push

@pytest.mark.asyncio
def test_config_path_move(self, kubeconfig_path, mock_get_connection, tmp_path):
new_config_path = tmp_path / "kube_config.cfg"
shutil.copy(kubeconfig_path, new_config_path)
@@ -210,7 +208,6 @@ def test_config_path_move(self, kubeconfig_path, mock_get_connection):
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
assert actual_pod == expected_pod

@pytest.mark.asyncio
def test_working_pod(self, mock_get_connection):
k = KubernetesPodOperator(
namespace="default",
@@ -228,7 +225,6 @@ def test_working_pod(self, mock_get_connection):
assert self.expected_pod["spec"] == actual_pod["spec"]
assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"]

@pytest.mark.asyncio
def test_skip_cleanup(self, mock_get_connection):
k = KubernetesPodOperator(
namespace="unknown",
@@ -244,7 +240,6 @@ def test_skip_cleanup(self, mock_get_connection):
with pytest.raises(ApiException):
k.execute(context)

@pytest.mark.asyncio
def test_delete_operator_pod(self, mock_get_connection):
k = KubernetesPodOperator(
namespace="default",
@@ -263,7 +258,6 @@ def test_delete_operator_pod(self, mock_get_connection):
assert self.expected_pod["spec"] == actual_pod["spec"]
assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"]

@pytest.mark.asyncio
def test_skip_on_specified_exit_code(self, mock_get_connection):
k = KubernetesPodOperator(
namespace="default",
@@ -280,7 +274,6 @@ def test_skip_on_specified_exit_code(self, mock_get_connection):
with pytest.raises(AirflowSkipException):
k.execute(context)

@pytest.mark.asyncio
def test_already_checked_on_success(self, mock_get_connection):
"""
When ``on_finish_action="keep_pod"``, pod should have 'already_checked'
@@ -303,7 +296,6 @@ def test_already_checked_on_success(self, mock_get_connection):
actual_pod = self.api_client.sanitize_for_serialization(actual_pod)
assert actual_pod["metadata"]["labels"]["already_checked"] == "True"

@pytest.mark.asyncio
def test_already_checked_on_failure(self, mock_get_connection):
"""
When ``on_finish_action="keep_pod"``, pod should have 'already_checked'
@@ -329,7 +321,6 @@ def test_already_checked_on_failure(self, mock_get_connection):
assert status["state"]["terminated"]["reason"] == "Error"
assert actual_pod["metadata"]["labels"]["already_checked"] == "True"

@pytest.mark.asyncio
def test_pod_hostnetwork(self, mock_get_connection):
k = KubernetesPodOperator(
namespace="default",
@@ -349,7 +340,6 @@ def test_pod_hostnetwork(self, mock_get_connection):
assert self.expected_pod["spec"] == actual_pod["spec"]
assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"]

@pytest.mark.asyncio
def test_pod_dnspolicy(self, mock_get_connection):
dns_policy = "ClusterFirstWithHostNet"
k = KubernetesPodOperator(
@@ -372,7 +362,6 @@ def test_pod_dnspolicy(self, mock_get_connection):
assert self.expected_pod["spec"] == actual_pod["spec"]
assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"]

@pytest.mark.asyncio
def test_pod_schedulername(self, mock_get_connection):
scheduler_name = "default-scheduler"
k = KubernetesPodOperator(
@@ -392,7 +381,6 @@ def test_pod_schedulername(self, mock_get_connection):
self.expected_pod["spec"]["schedulerName"] = scheduler_name
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
def test_pod_node_selector(self, mock_get_connection):
node_selector = {"beta.kubernetes.io/os": "linux"}
k = KubernetesPodOperator(
@@ -412,7 +400,6 @@ def test_pod_node_selector(self, mock_get_connection):
self.expected_pod["spec"]["nodeSelector"] = node_selector
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
def test_pod_resources(self, mock_get_connection):
resources = k8s.V1ResourceRequirements(
requests={"memory": "64Mi", "cpu": "250m", "ephemeral-storage": "1Gi"},
@@ -438,7 +425,6 @@ def test_pod_resources(self, mock_get_connection):
}
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
@pytest.mark.parametrize(
"val",
[
@@ -515,7 +501,6 @@ def test_pod_affinity(self, val, mock_get_connection):
self.expected_pod["spec"]["affinity"] = expected
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
def test_port(self, mock_get_connection):
port = k8s.V1ContainerPort(
name="http",
@@ -539,7 +524,6 @@ def test_port(self, mock_get_connection):
self.expected_pod["spec"]["containers"][0]["ports"] = [{"name": "http", "containerPort": 80}]
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
def test_volume_mount(self, mock_get_connection):
with mock.patch.object(PodManager, "log") as mock_logger:
volume_mount = k8s.V1VolumeMount(
@@ -579,7 +563,6 @@ def test_volume_mount(self, mock_get_connection):
]
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
@pytest.mark.parametrize("uid", [0, 1000])
def test_run_as_user(self, uid, mock_get_connection):
security_context = V1PodSecurityContext(run_as_user=uid)
@@ -605,7 +588,6 @@ def test_run_as_user(self, uid, mock_get_connection):
)
assert pod.to_dict()["spec"]["security_context"]["run_as_user"] == uid

@pytest.mark.asyncio
@pytest.mark.parametrize("gid", [0, 1000])
def test_fs_group(self, gid, mock_get_connection):
security_context = V1PodSecurityContext(fs_group=gid)
@@ -631,7 +613,6 @@ def test_fs_group(self, gid, mock_get_connection):
)
assert pod.to_dict()["spec"]["security_context"]["fs_group"] == gid

@pytest.mark.asyncio
def test_disable_privilege_escalation(self, mock_get_connection):
container_security_context = V1SecurityContext(allow_privilege_escalation=False)

@@ -654,7 +635,6 @@ def test_disable_privilege_escalation(self, mock_get_connection):
}
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
def test_faulty_image(self, mock_get_connection):
bad_image_name = "foobar"
k = KubernetesPodOperator(
@@ -693,7 +673,6 @@ def test_faulty_service_account(self, mock_get_connection):
with pytest.raises(ApiException, match="error looking up service account default/foobar"):
k.get_or_create_pod(pod, context)

@pytest.mark.asyncio
def test_pod_failure(self, mock_get_connection):
"""
Tests that the task fails when a pod reports a failure
@@ -716,7 +695,6 @@ def test_pod_failure(self, mock_get_connection):
self.expected_pod["spec"]["containers"][0]["args"] = bad_internal_command
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
def test_xcom_push(self, test_label, mock_get_connection):
expected = {"test_label": test_label, "buzz": 2}
args = [f"echo '{json.dumps(expected)}' > /airflow/xcom/return.json"]
@@ -765,7 +743,6 @@ def test_env_vars(self, mock_get_connection):
]
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
def test_pod_template_file_system(self, mock_get_connection, basic_pod_template):
"""Note: this test requires that you have a namespace ``mem-example`` in your cluster."""
k = KubernetesPodOperator(
@@ -781,7 +758,6 @@ def test_pod_template_file_system(self, mock_get_connection, basic_pod_template)
assert result is not None
assert result == {"hello": "world"}

@pytest.mark.asyncio
@pytest.mark.parametrize(
"env_vars",
[
@@ -817,7 +793,6 @@ def test_pod_template_file_with_overrides_system(
assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")]
assert result == {"hello": "world"}

@pytest.mark.asyncio
def test_pod_template_file_with_full_pod_spec(self, test_label, mock_get_connection, basic_pod_template):
pod_spec = k8s.V1Pod(
metadata=k8s.V1ObjectMeta(
@@ -858,7 +833,6 @@ def test_pod_template_file_with_full_pod_spec(self, test_label, mock_get_connect
assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")]
assert result == {"hello": "world"}

@pytest.mark.asyncio
def test_full_pod_spec(self, test_label, mock_get_connection):
pod_spec = k8s.V1Pod(
metadata=k8s.V1ObjectMeta(
@@ -904,7 +878,6 @@ def test_full_pod_spec(self, test_label, mock_get_connection):
assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")]
assert result == {"hello": "world"}

@pytest.mark.asyncio
def test_init_container(self, mock_get_connection):
# GIVEN
volume_mounts = [
@@ -959,10 +932,11 @@ def test_init_container(self, mock_get_connection):
]
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
@mock.patch(f"{POD_MANAGER_CLASS}.await_xcom_sidecar_container_start")
@mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom")
@mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
@mock.patch(f"{POD_MANAGER_CLASS}.watch_pod_events", new=AsyncMock())
@mock.patch(f"{POD_MANAGER_CLASS}.await_pod_start", new=AsyncMock())
@mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock)
@mock.patch(HOOK_CLASS)
def test_pod_template_file(
@@ -1067,8 +1041,9 @@ def test_pod_template_file(
del actual_pod["metadata"]["labels"]["airflow_version"]
assert expected_dict == actual_pod

@pytest.mark.asyncio
@mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
@mock.patch(f"{POD_MANAGER_CLASS}.watch_pod_events", new=AsyncMock())
@mock.patch(f"{POD_MANAGER_CLASS}.await_pod_start", new=AsyncMock())
@mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock)
@mock.patch(HOOK_CLASS)
def test_pod_priority_class_name(self, hook_mock, await_pod_completion_mock):
@@ -1102,7 +1077,6 @@ def test_pod_priority_class_name(self, hook_mock, await_pod_completion_mock):
self.expected_pod["spec"]["priorityClassName"] = priority_class_name
assert self.expected_pod == actual_pod

@pytest.mark.asyncio
def test_pod_name(self, mock_get_connection):
pod_name_too_long = "a" * 221
k = KubernetesPodOperator(
@@ -1122,7 +1096,6 @@ def test_pod_name(self, mock_get_connection):
with pytest.raises(AirflowException):
k.execute(context)

@pytest.mark.asyncio
def test_on_kill(self, mock_get_connection):
hook = KubernetesHook(conn_id=None, in_cluster=False)
client = hook.core_v1_client
@@ -1163,7 +1136,6 @@ class ShortCircuitException(Exception):
with pytest.raises(ApiException, match=r'pods \"test.[a-z0-9]+\" not found'):
client.read_namespaced_pod(name=name, namespace=namespace)

@pytest.mark.asyncio
def test_reattach_failing_pod_once(self, mock_get_connection):
hook = KubernetesHook(conn_id=None, in_cluster=False)
client = hook.core_v1_client
@@ -1224,19 +1196,13 @@ def get_op():
# recreate op just to ensure we're not relying on any statefulness
k = get_op()

# Before next attempt we need to re-create event loop if it is closed.
loop = asyncio.get_event_loop()
if loop.is_closed():
asyncio.set_event_loop(asyncio.new_event_loop())

# `create_pod` should be called because though there's still a pod to be found,
# it will be `already_checked`
with mock.patch(f"{POD_MANAGER_CLASS}.create_pod") as create_mock:
with pytest.raises(ApiException, match=r'pods \"test.[a-z0-9]+\" not found'):
k.execute(context)
create_mock.assert_called_once()

@pytest.mark.asyncio
def test_changing_base_container_name_with_get_logs(self, mock_get_connection):
k = KubernetesPodOperator(
namespace="default",
@@ -1262,7 +1228,6 @@ def test_changing_base_container_name_with_get_logs(self, mock_get_connection):
self.expected_pod["spec"]["containers"][0]["name"] = "apple-sauce"
assert self.expected_pod["spec"] == actual_pod["spec"]

@pytest.mark.asyncio
def test_changing_base_container_name_no_logs(self, mock_get_connection):
"""
This test checks BOTH a modified base container name AND the get_logs=False flow,
@@ -1293,7 +1258,6 @@ def test_changing_base_container_name_no_logs(self, mock_get_connection):
self.expected_pod["spec"]["containers"][0]["name"] = "apple-sauce"
assert self.expected_pod["spec"] == actual_pod["spec"]

@pytest.mark.asyncio
def test_changing_base_container_name_no_logs_long(self, mock_get_connection):
"""
Similar to test_changing_base_container_name_no_logs, but ensures that
@@ -1325,7 +1289,6 @@ def test_changing_base_container_name_no_logs_long(self, mock_get_connection):
self.expected_pod["spec"]["containers"][0]["args"] = ["sleep 3"]
assert self.expected_pod["spec"] == actual_pod["spec"]

@pytest.mark.asyncio
def test_changing_base_container_name_failure(self, mock_get_connection):
k = KubernetesPodOperator(
namespace="default",
@@ -1372,7 +1335,6 @@ class MyK8SPodOperator(KubernetesPodOperator):
)
assert MyK8SPodOperator(task_id=str(uuid4())).base_container_name == "tomato-sauce"

@pytest.mark.asyncio
def test_init_container_logs(self, mock_get_connection):
marker_from_init_container = f"{uuid4()}"
marker_from_main_container = f"{uuid4()}"
@@ -1404,7 +1366,6 @@ def test_init_container_logs(self, mock_get_connection):
assert marker_from_init_container in calls_args
assert marker_from_main_container in calls_args

@pytest.mark.asyncio
def test_init_container_logs_filtered(self, mock_get_connection):
marker_from_init_container_to_log_1 = f"{uuid4()}"
marker_from_init_container_to_log_2 = f"{uuid4()}"
@@ -1502,7 +1463,6 @@ def __getattr__(self, name):


class TestKubernetesPodOperator(BaseK8STest):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"active_deadline_seconds,should_fail",
[(3, True), (60, False)],
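The hunks above drop `@pytest.mark.asyncio` (and the `import asyncio` that went with it) from tests that are synchronous end to end: the operator now drives its own event loop internally, so the tests only need awaitable stand-ins for the two async `PodManager` methods, which is what the new `AsyncMock` patches for `watch_pod_events` and `await_pod_start` provide. A minimal sketch of that pattern, using hypothetical `Manager` and `run_sync` names rather than the real classes:

```python
import asyncio
from unittest import mock
from unittest.mock import AsyncMock


class Manager:
    async def watch(self) -> str:
        return "real"


def run_sync() -> str:
    # The code under test owns its own event loop, so callers stay synchronous.
    return asyncio.run(Manager().watch())


@mock.patch.object(Manager, "watch", new=AsyncMock(return_value="stubbed"))
def test_run_sync() -> None:
    # No @pytest.mark.asyncio needed: the test itself never awaits anything.
    assert run_sync() == "stubbed"
```

Note the asymmetry in the patches above: `new=AsyncMock()` hands over a ready-made instance whose call returns an awaitable, while `new=MagicMock` (no parentheses) hands over the class itself. Either works for `create_pod`, which only has to be callable, but the awaited methods need the `AsyncMock` form.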
@@ -608,20 +608,18 @@ def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context: Context) -> k8s

def await_pod_start(self, pod: k8s.V1Pod) -> None:
try:
loop = asyncio.get_event_loop()
events_task = asyncio.ensure_future(
self.pod_manager.watch_pod_events(pod, self.startup_check_interval_seconds)
)
loop.run_until_complete(
self.pod_manager.await_pod_start(

async def _await_pod_start():
events_task = self.pod_manager.watch_pod_events(pod, self.startup_check_interval_seconds)
pod_start_task = self.pod_manager.await_pod_start(
pod=pod,
schedule_timeout=self.schedule_timeout_seconds,
startup_timeout=self.startup_timeout_seconds,
check_interval=self.startup_check_interval_seconds,
)
)
loop.run_until_complete(events_task)
loop.close()
await asyncio.gather(pod_start_task, events_task)

asyncio.run(_await_pod_start())
except PodLaunchFailedException:
if self.log_events_on_failure:
self._read_pod_events(pod, reraise=False)
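The hunk above rewrites `KubernetesPodOperator.await_pod_start` so that, instead of fetching a loop with `asyncio.get_event_loop()`, driving two separate `run_until_complete` calls, and closing the loop by hand, it wraps the whole sequence in a nested coroutine and hands it to `asyncio.run`, with `asyncio.gather` tying the event watcher to the readiness wait so both finish (or fail) together. This is also why the test hunk earlier could delete the "re-create event loop if it is closed" workaround: `asyncio.run` builds a fresh loop on every call and tears it down afterwards. A minimal sketch of the pattern, with stand-in coroutines rather than the real `PodManager` methods:

```python
import asyncio


async def watch_events() -> None:
    for _ in range(3):  # stand-in for the event-log polling loop
        await asyncio.sleep(0.1)


async def wait_until_started() -> None:
    await asyncio.sleep(0.2)  # stand-in for the pod readiness poll


def await_start() -> None:
    async def _run() -> None:
        # gather() runs both coroutines on the same loop and propagates
        # either one's exception to the synchronous caller.
        await asyncio.gather(wait_until_started(), watch_events())

    # asyncio.run() creates a fresh loop, runs the coroutine to completion,
    # and closes the loop, so repeated calls (e.g. task retries) never hit
    # a stale, already-closed loop the way the old run_until_complete
    # sequence could.
    asyncio.run(_run())


await_start()  # plain synchronous call; no loop management at the call site
```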
@@ -338,6 +338,7 @@ def __init__(
self._client = kube_client
self._watch = watch.Watch()
self._callbacks = callbacks or []
self.stop_watching_events = False

def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod:
"""Run POD asynchronously."""
@@ -380,9 +381,8 @@ def create_pod(self, pod: V1Pod) -> V1Pod:

async def watch_pod_events(self, pod: V1Pod, check_interval: int = 1) -> None:
"""Read pod events and writes into log."""
self.keep_watching_for_events = True
num_events = 0
while self.keep_watching_for_events:
while not self.stop_watching_events:
events = self.read_pod_events(pod)
for new_event in events.items[num_events:]:
involved_object: V1ObjectReference = new_event.involved_object
@@ -413,7 +413,7 @@ async def await_pod_start(
remote_pod = self.read_pod(pod)
pod_status = remote_pod.status
if pod_status.phase != PodPhase.PENDING:
self.keep_watching_for_events = False
self.stop_watching_events = True
self.log.info("::endgroup::")
break

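The `PodManager` hunks above invert and relocate the watcher's flag. The old `keep_watching_for_events` attribute was created only once `watch_pod_events` started executing, so the coroutine had to re-establish its own default on every call; the renamed `stop_watching_events` is initialized to `False` in `__init__`, which guarantees the attribute exists for the object's entire lifetime and lets `await_pod_start` flip it as soon as the pod leaves the `PENDING` phase. A minimal sketch of this constructor-initialized stop-flag handshake, with hypothetical names rather than the real `PodManager` API:

```python
import asyncio


class Watcher:
    def __init__(self) -> None:
        # Initialize in the constructor so the flag exists no matter
        # which coroutine the event loop schedules first.
        self.stop = False

    async def watch(self) -> None:
        while not self.stop:  # poll until the starter signals us
            await asyncio.sleep(0.1)

    async def wait_ready(self) -> None:
        await asyncio.sleep(0.3)  # stand-in for the pod phase check
        self.stop = True  # tell the watcher to wind down


async def main() -> None:
    w = Watcher()
    # Mirrors the gather() in await_pod_start: the watcher exits once
    # wait_ready() sets the flag, so gather() returns cleanly.
    await asyncio.gather(w.wait_ready(), w.watch())


asyncio.run(main())
```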