diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py index b120f2590812f..c3a5f91bd65cb 100644 --- a/airflow/kubernetes/pod_generator.py +++ b/airflow/kubernetes/pod_generator.py @@ -25,6 +25,7 @@ import copy import datetime import hashlib +import logging import os import re import uuid @@ -40,6 +41,8 @@ from airflow.utils import yaml from airflow.version import version as airflow_version +log = logging.getLogger(__name__) + MAX_LABEL_LEN = 63 @@ -412,24 +415,25 @@ def deserialize_model_file(path: str) -> k8s.V1Pod: """ :param path: Path to the file :return: a kubernetes.client.models.V1Pod - - Unfortunately we need access to the private method - ``_ApiClient__deserialize_model`` from the kubernetes client. - This issue is tracked here; https://github.com/kubernetes-client/python/issues/977. """ if os.path.exists(path): with open(path) as stream: pod = yaml.safe_load(stream) else: - pod = yaml.safe_load(path) + pod = None + log.warning("Model file %s does not exist", path) return PodGenerator.deserialize_model_dict(pod) @staticmethod - def deserialize_model_dict(pod_dict: dict) -> k8s.V1Pod: + def deserialize_model_dict(pod_dict: dict | None) -> k8s.V1Pod: """ Deserializes python dictionary to k8s.V1Pod + Unfortunately we need access to the private method + ``_ApiClient__deserialize_model`` from the kubernetes client. + This issue is tracked here; https://github.com/kubernetes-client/python/issues/977. + :param pod_dict: Serialized dict of k8s.V1Pod object :return: De-serialized k8s.V1Pod """ diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py index 0d19c8237024b..574bd0afadcbf 100644 --- a/tests/executors/test_kubernetes_executor.py +++ b/tests/executors/test_kubernetes_executor.py @@ -20,11 +20,13 @@ import random import re import string +import sys import unittest from datetime import datetime, timedelta from unittest import mock import pytest +import yaml from kubernetes.client import models as k8s from kubernetes.client.rest import ApiException from urllib3 import HTTPResponse @@ -100,14 +102,33 @@ def test_create_pod_id(self): @mock.patch("airflow.kubernetes.pod_generator.PodGenerator") @mock.patch("airflow.executors.kubernetes_executor.KubeConfig") def test_get_base_pod_from_template(self, mock_kubeconfig, mock_generator): + # Provide non-existent file path, + # so None will be passed to deserialize_model_dict(). pod_template_file_path = "/bar/biz" get_base_pod_from_template(pod_template_file_path, None) assert "deserialize_model_dict" == mock_generator.mock_calls[0][0] - assert pod_template_file_path == mock_generator.mock_calls[0][1][0] + assert mock_generator.mock_calls[0][1][0] is None + mock_kubeconfig.pod_template_file = "/foo/bar" get_base_pod_from_template(None, mock_kubeconfig) assert "deserialize_model_dict" == mock_generator.mock_calls[1][0] - assert "/foo/bar" == mock_generator.mock_calls[1][1][0] + assert mock_generator.mock_calls[1][1][0] is None + + # Provide existent file path, + # so loaded YAML file content should be used to call deserialize_model_dict(), rather than None. + path = sys.path[0] + '/tests/kubernetes/pod.yaml' + with open(path) as stream: + expected_pod_dict = yaml.safe_load(stream) + + pod_template_file_path = path + get_base_pod_from_template(pod_template_file_path, None) + assert "deserialize_model_dict" == mock_generator.mock_calls[2][0] + assert mock_generator.mock_calls[2][1][0] == expected_pod_dict + + mock_kubeconfig.pod_template_file = path + get_base_pod_from_template(None, mock_kubeconfig) + assert "deserialize_model_dict" == mock_generator.mock_calls[3][0] + assert mock_generator.mock_calls[3][1][0] == expected_pod_dict def test_make_safe_label_value(self): for dag_id, task_id in self._cases(): @@ -228,8 +249,6 @@ def test_run_next_exception_requeue( - 400 BadRequest is returned when your parameters are invalid e.g. asking for cpu=100ABC123. """ - import sys - path = sys.path[0] + '/tests/kubernetes/pod_generator_base_with_secrets.yaml' response = HTTPResponse(body='{"message": "any message"}', status=status) @@ -283,8 +302,6 @@ def test_run_next_pod_reconciliation_error(self, mock_get_kube_client, mock_kube """ When construct_pod raises PodReconciliationError, we should fail the task. """ - import sys - path = sys.path[0] + '/tests/kubernetes/pod_generator_base_with_secrets.yaml' mock_kube_client = mock.patch('kubernetes.client.CoreV1Api', autospec=True) diff --git a/tests/kubernetes/test_pod_generator.py b/tests/kubernetes/test_pod_generator.py index 8b53762fe241c..57039c04e8600 100644 --- a/tests/kubernetes/test_pod_generator.py +++ b/tests/kubernetes/test_pod_generator.py @@ -712,11 +712,20 @@ def test_reconcile_specs_init_containers(self): res = PodGenerator.reconcile_specs(base_spec, client_spec) assert res.init_containers == base_spec.init_containers + client_spec.init_containers - def test_deserialize_model_file(self): + def test_deserialize_model_file(self, caplog): path = sys.path[0] + '/tests/kubernetes/pod.yaml' result = PodGenerator.deserialize_model_file(path) sanitized_res = self.k8s_client.sanitize_for_serialization(result) assert sanitized_res == self.deserialize_result + assert len(caplog.records) == 0 + + def test_deserialize_non_existent_model_file(self, caplog): + path = sys.path[0] + '/tests/kubernetes/non_existent.yaml' + result = PodGenerator.deserialize_model_file(path) + sanitized_res = self.k8s_client.sanitize_for_serialization(result) + assert sanitized_res == {} + assert len(caplog.records) == 1 + assert 'does not exist' in caplog.text @parameterized.expand( ( @@ -761,29 +770,6 @@ def test_pod_name_is_valid(self, pod_id, expected_starts_with): assert name.rsplit("-", 1)[0] == expected_starts_with - def test_deserialize_model_string(self): - fixture = """ -apiVersion: v1 -kind: Pod -metadata: - name: memory-demo - namespace: mem-example -spec: - containers: - - name: memory-demo-ctr - image: ghcr.io/apache/airflow-stress:1.0.4-2021.07.04 - resources: - limits: - memory: "200Mi" - requests: - memory: "100Mi" - command: ["stress"] - args: ["--vm", "1", "--vm-bytes", "150M", "--vm-hang", "1"] - """ - result = PodGenerator.deserialize_model_file(fixture) - sanitized_res = self.k8s_client.sanitize_for_serialization(result) - assert sanitized_res == self.deserialize_result - def test_validate_pod_generator(self): with pytest.raises(AirflowConfigException): PodGenerator(pod=k8s.V1Pod(), pod_template_file='k')