Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions airflow/providers/cncf/kubernetes/hooks/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,18 +653,19 @@ def _get_bool(val) -> bool | None:
class AsyncKubernetesHook(KubernetesHook):
"""Hook to use Kubernetes SDK asynchronously."""

def __init__(self, *args, **kwargs):
def __init__(self, config_dict: dict | None = None, *args, **kwargs):
super().__init__(*args, **kwargs)

self.config_dict = config_dict
self._extras: dict | None = None

async def _load_config(self):
"""Return Kubernetes API session for use with requests."""
in_cluster = self._coalesce_param(self.in_cluster, await self._get_field("in_cluster"))
cluster_context = self._coalesce_param(self.cluster_context, await self._get_field("cluster_context"))
kubeconfig_path = self._coalesce_param(self.config_file, await self._get_field("kube_config_path"))
kubeconfig = await self._get_field("kube_config")

num_selected_configuration = sum(1 for o in [in_cluster, kubeconfig, kubeconfig_path] if o)
num_selected_configuration = sum(1 for o in [in_cluster, kubeconfig, self.config_dict] if o)

if num_selected_configuration > 1:
raise AirflowException(
Expand All @@ -679,14 +680,9 @@ async def _load_config(self):
async_config.load_incluster_config()
return async_client.ApiClient()

if kubeconfig_path:
self.log.debug(LOADING_KUBE_CONFIG_FILE_RESOURCE.format("kube_config"))
self._is_in_cluster = False
await async_config.load_kube_config(
config_file=kubeconfig_path,
client_configuration=self.client_configuration,
context=cluster_context,
)
if self.config_dict:
self.log.debug(LOADING_KUBE_CONFIG_FILE_RESOURCE.format("config dictionary"))
await async_config.load_kube_config_from_dict(self.config_dict)
return async_client.ApiClient()

if kubeconfig is not None:
Expand Down
13 changes: 12 additions & 1 deletion airflow/providers/cncf/kubernetes/operators/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import json
import logging
import math
import os
import re
import shlex
import string
Expand Down Expand Up @@ -695,8 +696,18 @@ def execute_async(self, context: Context) -> None:
ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace)

self.convert_config_file_to_dict()
self.invoke_defer_method()

def convert_config_file_to_dict(self):
"""Convert passed config_file to dict representation."""
config_file = self.config_file if self.config_file else os.environ.get(KUBE_CONFIG_ENV_VAR)
if config_file:
with open(config_file) as f:
self._config_dict = yaml.safe_load(f)
else:
self._config_dict = None

def invoke_defer_method(self, last_log_time: DateTime | None = None) -> None:
"""Redefine triggers which are being used in child classes."""
trigger_start_time = datetime.datetime.now(tz=datetime.timezone.utc)
Expand All @@ -707,7 +718,7 @@ def invoke_defer_method(self, last_log_time: DateTime | None = None) -> None:
trigger_start_time=trigger_start_time,
kubernetes_conn_id=self.kubernetes_conn_id,
cluster_context=self.cluster_context,
config_file=self.config_file,
config_dict=self._config_dict,
in_cluster=self.in_cluster,
poll_interval=self.poll_interval,
get_logs=self.get_logs,
Expand Down
10 changes: 5 additions & 5 deletions airflow/providers/cncf/kubernetes/triggers/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class KubernetesPodTrigger(BaseTrigger):
:param kubernetes_conn_id: The :ref:`kubernetes connection id <howto/connection:kubernetes>`
for the Kubernetes cluster.
:param cluster_context: Context that points to kubernetes cluster.
:param config_file: Path to kubeconfig file.
:param config_dict: Content of kubeconfig file in dict format.
:param poll_interval: Polling period in seconds to check for the status.
:param trigger_start_time: time in Datetime format when the trigger was started
:param in_cluster: run kubernetes client with in_cluster configuration.
Expand Down Expand Up @@ -89,7 +89,7 @@ def __init__(
kubernetes_conn_id: str | None = None,
poll_interval: float = 2,
cluster_context: str | None = None,
config_file: str | None = None,
config_dict: dict | None = None,
in_cluster: bool | None = None,
get_logs: bool = True,
startup_timeout: int = 120,
Expand All @@ -107,7 +107,7 @@ def __init__(
self.kubernetes_conn_id = kubernetes_conn_id
self.poll_interval = poll_interval
self.cluster_context = cluster_context
self.config_file = config_file
self.config_dict = config_dict
self.in_cluster = in_cluster
self.get_logs = get_logs
self.startup_timeout = startup_timeout
Expand Down Expand Up @@ -142,7 +142,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"kubernetes_conn_id": self.kubernetes_conn_id,
"poll_interval": self.poll_interval,
"cluster_context": self.cluster_context,
"config_file": self.config_file,
"config_dict": self.config_dict,
"in_cluster": self.in_cluster,
"get_logs": self.get_logs,
"startup_timeout": self.startup_timeout,
Expand Down Expand Up @@ -282,7 +282,7 @@ def _get_async_hook(self) -> AsyncKubernetesHook:
return AsyncKubernetesHook(
conn_id=self.kubernetes_conn_id,
in_cluster=self.in_cluster,
config_file=self.config_file,
config_dict=self.config_dict,
cluster_context=self.cluster_context,
)

Expand Down
8 changes: 5 additions & 3 deletions tests/providers/cncf/kubernetes/operators/test_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from contextlib import contextmanager, nullcontext
from io import BytesIO
from unittest import mock
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, mock_open, patch

import pendulum
import pytest
Expand Down Expand Up @@ -1866,7 +1866,7 @@ def run_pod_async(self, operator: KubernetesPodOperator, map_index: int = -1):
@patch(KUB_OP_PATH.format("build_pod_request_obj"))
@patch(KUB_OP_PATH.format("get_or_create_pod"))
def test_async_create_pod_should_execute_successfully(
self, mocked_pod, mocked_pod_obj, mocked_found_pod, mocked_client, do_xcom_push
self, mocked_pod, mocked_pod_obj, mocked_found_pod, mocked_client, do_xcom_push, mocker
):
"""
Asserts that a task is deferred and the KubernetesCreatePodTrigger will be fired
Expand All @@ -1889,7 +1889,9 @@ def test_async_create_pod_should_execute_successfully(
deferrable=True,
do_xcom_push=do_xcom_push,
)
k.config_file_in_dict_representation = {"a": "b"}

mock_file = mock_open(read_data='{"a": "b"}')
mocker.patch("builtins.open", mock_file)

mocked_pod.return_value.metadata.name = TEST_NAME
mocked_pod.return_value.metadata.namespace = TEST_NAMESPACE
Expand Down
6 changes: 3 additions & 3 deletions tests/providers/cncf/kubernetes/triggers/test_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
CONN_ID = "test_kubernetes_conn_id"
POLL_INTERVAL = 2
CLUSTER_CONTEXT = "test-context"
CONFIG_FILE = "/path/to/config/file"
CONFIG_DICT = {"a": "b"}
IN_CLUSTER = False
GET_LOGS = True
STARTUP_TIMEOUT_SECS = 120
Expand All @@ -58,7 +58,7 @@ def trigger():
kubernetes_conn_id=CONN_ID,
poll_interval=POLL_INTERVAL,
cluster_context=CLUSTER_CONTEXT,
config_file=CONFIG_FILE,
config_dict=CONFIG_DICT,
in_cluster=IN_CLUSTER,
get_logs=GET_LOGS,
startup_timeout=STARTUP_TIMEOUT_SECS,
Expand Down Expand Up @@ -101,7 +101,7 @@ def test_serialize(self, trigger):
"kubernetes_conn_id": CONN_ID,
"poll_interval": POLL_INTERVAL,
"cluster_context": CLUSTER_CONTEXT,
"config_file": CONFIG_FILE,
"config_dict": CONFIG_DICT,
"in_cluster": IN_CLUSTER,
"get_logs": GET_LOGS,
"startup_timeout": STARTUP_TIMEOUT_SECS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import json
import os
from unittest import mock
from unittest.mock import mock_open

import pytest
from google.cloud.container_v1.types import Cluster, NodePool
Expand Down Expand Up @@ -739,12 +740,15 @@ def setup_method(self):
)
@mock.patch(f"{GKE_OP_PATH}.fetch_cluster_info")
def test_async_create_pod_should_execute_successfully(
self, fetch_cluster_info_mock, get_con_mock, mocked_pod, mocked_pod_obj
self, fetch_cluster_info_mock, get_con_mock, mocked_pod, mocked_pod_obj, mocker
):
"""
Asserts that a task is deferred and the GKEStartPodTrigger will be fired
when the GKEStartPodOperator is executed in deferrable mode when deferrable=True.
"""
mock_file = mock_open(read_data='{"a": "b"}')
mocker.patch("builtins.open", mock_file)

self.gke_op._cluster_url = CLUSTER_URL
self.gke_op._ssl_ca_cert = SSL_CA_CERT
with pytest.raises(TaskDeferred) as exc:
Expand Down