diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 8d99ec0470d64..2759eef2ff347 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -2429,11 +2429,21 @@ - name: multi_namespace_mode description: | Allows users to launch pods in multiple namespaces. - Will require creating a cluster-role for the scheduler + Will require creating a cluster-role for the scheduler, + or use multi_namespace_mode_namespace_list configuration. version_added: 1.10.12 type: boolean example: ~ default: "False" + - name: multi_namespace_mode_namespace_list + description: | + If multi_namespace_mode is True while scheduler does not have a cluster-role, + give the list of namespaces where the scheduler will schedule jobs + Scheduler needs to have the necessary permissions in these namespaces. + version_added: 2.6.0 + type: string + example: ~ + default: "" - name: in_cluster description: | Use the service account kubernetes gives to pods to connect to kubernetes cluster. diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 9d4e863c6e5b7..365de48f50da7 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -1215,9 +1215,15 @@ delete_worker_pods_on_failure = False worker_pods_creation_batch_size = 1 # Allows users to launch pods in multiple namespaces. -# Will require creating a cluster-role for the scheduler +# Will require creating a cluster-role for the scheduler, +# or use multi_namespace_mode_namespace_list configuration. multi_namespace_mode = False +# If multi_namespace_mode is True while scheduler does not have a cluster-role, +# give the list of namespaces where the scheduler will schedule jobs +# Scheduler needs to have the necessary permissions in these namespaces. +multi_namespace_mode_namespace_list = + # Use the service account kubernetes gives to pods to connect to kubernetes cluster. # It's intended for clients that expect to be running inside a pod running on kubernetes. # It will raise an exception if called from a process not running in a kubernetes environment. diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index 788f1b3fecb9a..d933a294e9cae 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -23,11 +23,11 @@ """ from __future__ import annotations -import functools import json import logging import multiprocessing import time +from collections import defaultdict from datetime import timedelta from queue import Empty, Queue from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple @@ -52,6 +52,8 @@ from airflow.utils.session import provide_session from airflow.utils.state import State +ALL_NAMESPACES = "ALL_NAMESPACES" + # TaskInstance key, command, configuration, pod_template_file KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any, Optional[str]] @@ -66,7 +68,7 @@ class ResourceVersion: """Singleton for tracking resourceVersion from Kubernetes.""" _instance = None - resource_version = "0" + resource_version: dict[str, str] = {} def __new__(cls): if cls._instance is None: @@ -79,8 +81,7 @@ class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin): def __init__( self, - namespace: str | None, - multi_namespace_mode: bool, + namespace: str, watcher_queue: Queue[KubernetesWatchType], resource_version: str | None, scheduler_job_id: str, @@ -88,7 +89,6 @@ def __init__( ): super().__init__() self.namespace = namespace - self.multi_namespace_mode = multi_namespace_mode self.scheduler_job_id = scheduler_job_id self.watcher_queue = watcher_queue self.resource_version = resource_version @@ -113,7 +113,7 @@ def run(self) -> None: except Exception: self.log.exception("Unknown error in KubernetesJobWatcher. Failing") self.resource_version = "0" - ResourceVersion().resource_version = "0" + ResourceVersion().resource_version[self.namespace] = "0" raise else: self.log.warning( @@ -121,6 +121,14 @@ def run(self) -> None: self.resource_version, ) + def _pod_events(self, kube_client: client.CoreV1Api, query_kwargs: dict): + watcher = watch.Watch() + + if self.namespace == ALL_NAMESPACES: + return watcher.stream(kube_client.list_pod_for_all_namespaces, **query_kwargs) + else: + return watcher.stream(kube_client.list_namespaced_pod, self.namespace, **query_kwargs) + def _run( self, kube_client: client.CoreV1Api, @@ -129,7 +137,6 @@ def _run( kube_config: Any, ) -> str | None: self.log.info("Event: and now my watch begins starting at resource_version: %s", resource_version) - watcher = watch.Watch() kwargs = {"label_selector": f"airflow-worker={scheduler_job_id}"} if resource_version: @@ -139,15 +146,8 @@ def _run( kwargs[key] = value last_resource_version: str | None = None - if self.multi_namespace_mode: - list_worker_pods = functools.partial( - watcher.stream, kube_client.list_pod_for_all_namespaces, **kwargs - ) - else: - list_worker_pods = functools.partial( - watcher.stream, kube_client.list_namespaced_pod, self.namespace, **kwargs - ) - for event in list_worker_pods(): + + for event in self._pod_events(kube_client=kube_client, query_kwargs=kwargs): task = event["object"] self.log.debug("Event: %s had an event of type %s", task.metadata.name, event["type"]) if event["type"] == "ERROR": @@ -251,7 +251,7 @@ def __init__( self._manager = multiprocessing.Manager() self.watcher_queue = self._manager.Queue() self.scheduler_job_id = scheduler_job_id - self.kube_watcher = self._make_kube_watcher() + self.kube_watchers = self._make_kube_watchers() def run_pod_async(self, pod: k8s.V1Pod, **kwargs): """Runs POD asynchronously.""" @@ -274,12 +274,11 @@ def run_pod_async(self, pod: k8s.V1Pod, **kwargs): raise e return resp - def _make_kube_watcher(self) -> KubernetesJobWatcher: - resource_version = ResourceVersion().resource_version + def _make_kube_watcher(self, namespace) -> KubernetesJobWatcher: + resource_version = ResourceVersion().resource_version.get(namespace, "0") watcher = KubernetesJobWatcher( watcher_queue=self.watcher_queue, - namespace=self.kube_config.kube_namespace, - multi_namespace_mode=self.kube_config.multi_namespace_mode, + namespace=namespace, resource_version=resource_version, scheduler_job_id=self.scheduler_job_id, kube_config=self.kube_config, @@ -287,15 +286,35 @@ def _make_kube_watcher(self) -> KubernetesJobWatcher: watcher.start() return watcher - def _health_check_kube_watcher(self): - if self.kube_watcher.is_alive(): - self.log.debug("KubeJobWatcher alive, continuing") - else: - self.log.error( - "Error while health checking kube watcher process. Process died for unknown reasons" + def _make_kube_watchers(self) -> dict[str, KubernetesJobWatcher]: + watchers = {} + if self.kube_config.multi_namespace_mode: + namespaces_to_watch = ( + self.kube_config.multi_namespace_mode_namespace_list + if self.kube_config.multi_namespace_mode_namespace_list + else [ALL_NAMESPACES] ) - ResourceVersion().resource_version = "0" - self.kube_watcher = self._make_kube_watcher() + else: + namespaces_to_watch = [self.kube_config.kube_namespace] + + for namespace in namespaces_to_watch: + watchers[namespace] = self._make_kube_watcher(namespace) + return watchers + + def _health_check_kube_watchers(self): + for namespace, kube_watcher in self.kube_watchers.items(): + if kube_watcher.is_alive(): + self.log.debug("KubeJobWatcher for namespace %s alive, continuing", namespace) + else: + self.log.error( + ( + "Error while health checking kube watcher process for namespace %s. " + "Process died for unknown reasons" + ), + namespace, + ) + ResourceVersion().resource_version[namespace] = "0" + self.kube_watchers[namespace] = self._make_kube_watcher(namespace) def run_next(self, next_job: KubernetesJobType) -> None: """Receives the next job to run, builds the pod, and creates it.""" @@ -363,7 +382,7 @@ def sync(self) -> None: """ self.log.debug("Syncing KubernetesExecutor") - self._health_check_kube_watcher() + self._health_check_kube_watchers() while True: try: task = self.watcher_queue.get_nowait() @@ -399,10 +418,11 @@ def _flush_watcher_queue(self) -> None: def terminate(self) -> None: """Terminates the watcher.""" - self.log.debug("Terminating kube_watcher...") - self.kube_watcher.terminate() - self.kube_watcher.join() - self.log.debug("kube_watcher=%s", self.kube_watcher) + self.log.debug("Terminating kube_watchers...") + for namespace, kube_watcher in self.kube_watchers.items(): + kube_watcher.terminate() + kube_watcher.join() + self.log.debug("kube_watcher=%s", kube_watcher) self.log.debug("Flushing watcher_queue...") self._flush_watcher_queue() # Queue should be empty... @@ -446,6 +466,23 @@ def __init__(self): self.kubernetes_queue: str | None = None super().__init__(parallelism=self.kube_config.parallelism) + def _list_pods(self, query_kwargs): + if self.kube_config.multi_namespace_mode: + if self.kube_config.multi_namespace_mode_namespace_list: + pods = [] + for namespace in self.kube_config.multi_namespace_mode_namespace_list: + pods.extend( + self.kube_client.list_namespaced_pod(namespace=namespace, **query_kwargs).items + ) + else: + pods = self.kube_client.list_pod_for_all_namespaces(**query_kwargs).items + else: + pods = self.kube_client.list_namespaced_pod( + namespace=self.kube_config.kube_namespace, **query_kwargs + ).items + + return pods + @provide_session def clear_not_launched_queued_tasks(self, session=None) -> None: """ @@ -501,16 +538,16 @@ def clear_not_launched_queued_tasks(self, session=None) -> None: # Try run_id first kwargs["label_selector"] += ",run_id=" + pod_generator.make_safe_label_value(ti.run_id) - pod_list = self.kube_client.list_namespaced_pod(self.kube_config.kube_namespace, **kwargs) - if pod_list.items: + pod_list = self._list_pods(kwargs) + if pod_list: continue # Fallback to old style of using execution_date kwargs["label_selector"] = ( f"{base_label_selector}," f"execution_date={pod_generator.datetime_to_label_safe_datestring(ti.execution_date)}" ) - pod_list = self.kube_client.list_namespaced_pod(self.kube_config.kube_namespace, **kwargs) - if pod_list.items: + pod_list = self._list_pods(kwargs) + if pod_list: continue self.log.info("TaskInstance: %s found in queued state but was not launched, rescheduling", ti) session.query(TaskInstance).filter( @@ -597,13 +634,13 @@ def sync(self) -> None: self.log.debug("self.queued: %s", self.queued_tasks) self.kube_scheduler.sync() - last_resource_version = None + last_resource_version: dict[str, str] = defaultdict(lambda: "0") while True: try: results = self.result_queue.get_nowait() try: key, state, pod_id, namespace, resource_version = results - last_resource_version = resource_version + last_resource_version[namespace] = resource_version self.log.info("Changing state of %s to %s", results, state) try: self._change_state(key, state, pod_id, namespace) @@ -621,7 +658,10 @@ def sync(self) -> None: break resource_instance = ResourceVersion() - resource_instance.resource_version = last_resource_version or resource_instance.resource_version + for ns in resource_instance.resource_version.keys(): + resource_instance.resource_version[ns] = ( + last_resource_version[ns] or resource_instance.resource_version[ns] + ) for _ in range(self.kube_config.worker_pods_creation_batch_size): try: @@ -681,15 +721,10 @@ def _check_worker_pods_pending_timeout(self): "label_selector": f"airflow-worker={self.scheduler_job_id}", **self.kube_config.kube_client_request_args, } - if self.kube_config.multi_namespace_mode: - pending_pods = functools.partial(self.kube_client.list_pod_for_all_namespaces, **kwargs) - else: - pending_pods = functools.partial( - self.kube_client.list_namespaced_pod, self.kube_config.kube_namespace, **kwargs - ) + pending_pods = self._list_pods(kwargs) cutoff = timezone.utcnow() - timedelta(seconds=timeout) - for pod in pending_pods().items: + for pod in pending_pods: self.log.debug( 'Found a pending pod "%s", created "%s"', pod.metadata.name, pod.metadata.creation_timestamp ) @@ -726,9 +761,9 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task kube_client: client.CoreV1Api = self.kube_client for scheduler_job_id in scheduler_job_ids: scheduler_job_id = pod_generator.make_safe_label_value(str(scheduler_job_id)) - kwargs = {"label_selector": f"airflow-worker={scheduler_job_id}"} - pod_list = kube_client.list_namespaced_pod(namespace=self.kube_config.kube_namespace, **kwargs) - for pod in pod_list.items: + query_kwargs = {"label_selector": f"airflow-worker={scheduler_job_id}"} + pod_list = self._list_pods(query_kwargs) + for pod in pod_list: self.adopt_launched_task(kube_client, pod, pod_ids) self._adopt_completed_pods(kube_client) tis_to_flush.extend(pod_ids.values()) @@ -775,12 +810,12 @@ def _adopt_completed_pods(self, kube_client: client.CoreV1Api) -> None: assert self.scheduler_job_id new_worker_id_label = pod_generator.make_safe_label_value(self.scheduler_job_id) - kwargs = { + query_kwargs = { "field_selector": "status.phase=Succeeded", "label_selector": f"kubernetes_executor=True,airflow-worker!={new_worker_id_label}", } - pod_list = kube_client.list_namespaced_pod(namespace=self.kube_config.kube_namespace, **kwargs) - for pod in pod_list.items: + pod_list = self._list_pods(query_kwargs) + for pod in pod_list: self.log.info("Attempting to adopt pod %s", pod.metadata.name) pod.metadata.labels["airflow-worker"] = new_worker_id_label try: diff --git a/airflow/kubernetes/kube_config.py b/airflow/kubernetes/kube_config.py index 0285f65208d01..8d2aa9c2faf5b 100644 --- a/airflow/kubernetes/kube_config.py +++ b/airflow/kubernetes/kube_config.py @@ -57,6 +57,14 @@ def __init__(self): # create, watch, get, and delete pods in this namespace. self.kube_namespace = conf.get(self.kubernetes_section, "namespace") self.multi_namespace_mode = conf.getboolean(self.kubernetes_section, "multi_namespace_mode") + if self.multi_namespace_mode and conf.get( + self.kubernetes_section, "multi_namespace_mode_namespace_list" + ): + self.multi_namespace_mode_namespace_list = conf.get( + self.kubernetes_section, "multi_namespace_mode_namespace_list" + ).split(",") + else: + self.multi_namespace_mode_namespace_list = None # The Kubernetes Namespace in which pods will be created by the executor. Note # that if your # cluster has RBAC enabled, your workers may need service account permissions to diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py index 367f1cb2c4f4a..cee1f77202ccc 100644 --- a/tests/executors/test_kubernetes_executor.py +++ b/tests/executors/test_kubernetes_executor.py @@ -521,6 +521,27 @@ def test_change_state_failed_no_deletion( assert executor.event_buffer[key][0] == State.FAILED mock_delete_pod.assert_not_called() + @pytest.mark.parametrize( + "multi_namespace_mode_namespace_list, watchers_keys", + [ + pytest.param(["A", "B", "C"], ["A", "B", "C"]), + pytest.param(None, ["ALL_NAMESPACES"]), + ], + ) + @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") + def test_watchers_under_multi_namespace_mode( + self, mock_get_kube_client, multi_namespace_mode_namespace_list, watchers_keys + ): + executor = self.kubernetes_executor + executor.kube_config.multi_namespace_mode = True + executor.kube_config.multi_namespace_mode_namespace_list = multi_namespace_mode_namespace_list + executor.start() + assert list(executor.kube_scheduler.kube_watchers.keys()) == watchers_keys + assert all( + isinstance(v, KubernetesJobWatcher) for v in executor.kube_scheduler.kube_watchers.values() + ) + executor.end() + @mock.patch("airflow.executors.kubernetes_executor.KubernetesJobWatcher") @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") @mock.patch("airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod") @@ -694,6 +715,27 @@ def test_not_adopt_unassigned_task(self, mock_kube_client): assert not mock_kube_client.patch_namespaced_pod.called assert pod_ids == {"foobar": {}} + @pytest.mark.parametrize( + "raw_multi_namespace_mode, raw_value_namespace_list, expected_value_in_kube_config", + [ + pytest.param("true", "A,B,C", ["A", "B", "C"]), + pytest.param("true", "", None), + pytest.param("false", "A,B,C", None), + pytest.param("false", "", None), + ], + ) + def test_kube_config_get_namespace_list( + self, raw_multi_namespace_mode, raw_value_namespace_list, expected_value_in_kube_config + ): + config = { + ("kubernetes", "multi_namespace_mode"): raw_multi_namespace_mode, + ("kubernetes", "multi_namespace_mode_namespace_list"): raw_value_namespace_list, + } + with conf_vars(config): + executor = KubernetesExecutor() + + assert executor.kube_config.multi_namespace_mode_namespace_list == expected_value_in_kube_config + @mock.patch("airflow.executors.kubernetes_executor.KubernetesJobWatcher") @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") @mock.patch("airflow.executors.kubernetes_executor.AirflowKubernetesScheduler") @@ -735,7 +777,7 @@ def test_pending_pod_timeout(self, mock_kubescheduler, mock_get_kube_client, moc executor._check_worker_pods_pending_timeout() mock_kube_client.list_namespaced_pod.assert_called_once_with( - "mynamespace", + namespace="mynamespace", field_selector="status.phase=Pending", label_selector="airflow-worker=123", limit=5, @@ -783,6 +825,65 @@ def test_pending_pod_timeout_multi_namespace_mode( ) mock_delete_pod.assert_called_once_with("foo90", "anothernamespace") + @mock.patch("airflow.executors.kubernetes_executor.KubernetesJobWatcher") + @mock.patch("airflow.executors.kubernetes_executor.get_kube_client") + @mock.patch("airflow.executors.kubernetes_executor.AirflowKubernetesScheduler") + def test_pending_pod_timeout_multi_namespace_mode_limited_namespaces( + self, mock_kubescheduler, mock_get_kube_client, mock_kubernetes_job_watcher + ): + mock_delete_pod = mock_kubescheduler.return_value.delete_pod + mock_kube_client = mock_get_kube_client.return_value + now = timezone.utcnow() + pending_pods = [ + k8s.V1Pod( + metadata=k8s.V1ObjectMeta( + name="foo90", + labels={"airflow-worker": "123"}, + creation_timestamp=now - timedelta(seconds=500), + namespace="namespace-2", + ) + ), + ] + + def list_namespaced_pod(namespace, *args, **kwargs): + if namespace == "namespace-2": + return k8s.V1PodList(items=pending_pods) + else: + return k8s.V1PodList(items=[]) + + mock_kube_client.list_namespaced_pod.side_effect = list_namespaced_pod + + config = { + ("kubernetes", "namespace"): "mynamespace", + ("kubernetes", "multi_namespace_mode"): "true", + ("kubernetes", "multi_namespace_mode_namespace_list"): "namespace-1,namespace-2,namespace-3", + ("kubernetes", "kube_client_request_args"): '{"sentinel": "foo"}', + } + with conf_vars(config): + executor = KubernetesExecutor() + executor.job_id = "123" + executor.start() + executor._check_worker_pods_pending_timeout() + executor.end() + + assert mock_kube_client.list_namespaced_pod.call_count == 3 + mock_kube_client.list_namespaced_pod.assert_has_calls( + [ + mock.call( + namespace=namespace, + field_selector="status.phase=Pending", + label_selector="airflow-worker=123", + limit=100, + sentinel="foo", + ) + for namespace in ["namespace-1", "namespace-2", "namespace-3"] + ] + ) + + mock_delete_pod.assert_called_once_with("foo90", "namespace-2") + # mock_delete_pod should only be called once in total + mock_delete_pod.assert_called_once() + def test_clear_not_launched_queued_tasks_not_launched(self, dag_maker, create_dummy_dag, session): """If a pod isn't found for a TI, reset the state to scheduled""" mock_kube_client = mock.MagicMock() @@ -805,12 +906,12 @@ def test_clear_not_launched_queued_tasks_not_launched(self, dag_maker, create_du assert ti.state == State.SCHEDULED assert mock_kube_client.list_namespaced_pod.call_count == 2 mock_kube_client.list_namespaced_pod.assert_any_call( - "default", label_selector="dag_id=test_clear,task_id=task1,airflow-worker=1,run_id=test" + namespace="default", label_selector="dag_id=test_clear,task_id=task1,airflow-worker=1,run_id=test" ) # also check that we fall back to execution_date if we didn't find the pod with run_id execution_date_label = pod_generator.datetime_to_label_safe_datestring(ti.execution_date) mock_kube_client.list_namespaced_pod.assert_called_with( - "default", + namespace="default", label_selector=( f"dag_id=test_clear,task_id=task1,airflow-worker=1,execution_date={execution_date_label}" ), @@ -849,7 +950,7 @@ def test_clear_not_launched_queued_tasks_launched( ti.refresh_from_db() assert ti.state == State.QUEUED mock_kube_client.list_namespaced_pod.assert_called_once_with( - "default", label_selector="dag_id=test_clear,task_id=task1,airflow-worker=1,run_id=test" + namespace="default", label_selector="dag_id=test_clear,task_id=task1,airflow-worker=1,run_id=test" ) def test_clear_not_launched_queued_tasks_mapped_task(self, dag_maker, session): @@ -893,15 +994,15 @@ def list_namespaced_pod(*args, **kwargs): mock_kube_client.list_namespaced_pod.assert_has_calls( [ mock.call( - "default", + namespace="default", label_selector="dag_id=test_clear,task_id=bash,airflow-worker=1,map_index=0,run_id=test", ), mock.call( - "default", + namespace="default", label_selector="dag_id=test_clear,task_id=bash,airflow-worker=1,map_index=1,run_id=test", ), mock.call( - "default", + namespace="default", label_selector=f"dag_id=test_clear,task_id=bash,airflow-worker=1,map_index=1," f"execution_date={execution_date_label}", ), @@ -968,10 +1069,11 @@ def test_clear_not_launched_queued_tasks_clear_only_by_job_id(self, dag_maker, c class TestKubernetesJobWatcher: + test_namespace = "airflow" + def setup_method(self): self.watcher = KubernetesJobWatcher( - namespace="airflow", - multi_namespace_mode=False, + namespace=self.test_namespace, watcher_queue=mock.MagicMock(), resource_version="0", scheduler_job_id="123", @@ -1108,9 +1210,9 @@ def effect(): except Exception as e: assert e.args == ("sentinel",) - # both resource_version should be 0 after _run raises and exception + # both resource_version should be 0 after _run raises an exception assert self.watcher.resource_version == "0" - assert ResourceVersion().resource_version == "0" + assert ResourceVersion().resource_version == {self.test_namespace: "0"} # check that in the next run, _run is invoked with resource_version = 0 mock_underscore_run.reset_mock()