Skip to content

Commit

Permalink
Clearer code for PodGenerator.deserialize_model_file (#26641)
Browse files Browse the repository at this point in the history
Especially for how it handles non-existent file.

When the file path received doesn't exist, the current way is to use yaml.safe_load()
to process it, and it returns the path as a string.

Then this string is passed to deserialize_model_dict() and results in an empty object.
Passing 'None' to deserialize_model_dict() will do the same, but the code will become clearer, and less misleading.

Meanwhile, when the model file path received doesn't exist, there should be a warning in the log.

(This change shouldn't cause any behaviour change)
  • Loading branch information
XD-DENG authored Sep 26, 2022
1 parent 465c564 commit 35deda4
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 36 deletions.
16 changes: 10 additions & 6 deletions airflow/kubernetes/pod_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import copy
import datetime
import hashlib
import logging
import os
import re
import uuid
Expand All @@ -40,6 +41,8 @@
from airflow.utils import yaml
from airflow.version import version as airflow_version

log = logging.getLogger(__name__)

MAX_LABEL_LEN = 63


Expand Down Expand Up @@ -412,24 +415,25 @@ def deserialize_model_file(path: str) -> k8s.V1Pod:
"""
:param path: Path to the file
:return: a kubernetes.client.models.V1Pod
Unfortunately we need access to the private method
``_ApiClient__deserialize_model`` from the kubernetes client.
This issue is tracked here; https://github.com/kubernetes-client/python/issues/977.
"""
if os.path.exists(path):
with open(path) as stream:
pod = yaml.safe_load(stream)
else:
pod = yaml.safe_load(path)
pod = None
log.warning("Model file %s does not exist", path)

return PodGenerator.deserialize_model_dict(pod)

@staticmethod
def deserialize_model_dict(pod_dict: dict) -> k8s.V1Pod:
def deserialize_model_dict(pod_dict: dict | None) -> k8s.V1Pod:
"""
Deserializes python dictionary to k8s.V1Pod
Unfortunately we need access to the private method
``_ApiClient__deserialize_model`` from the kubernetes client.
This issue is tracked here; https://github.com/kubernetes-client/python/issues/977.
:param pod_dict: Serialized dict of k8s.V1Pod object
:return: De-serialized k8s.V1Pod
"""
Expand Down
29 changes: 23 additions & 6 deletions tests/executors/test_kubernetes_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
import random
import re
import string
import sys
import unittest
from datetime import datetime, timedelta
from unittest import mock

import pytest
import yaml
from kubernetes.client import models as k8s
from kubernetes.client.rest import ApiException
from urllib3 import HTTPResponse
Expand Down Expand Up @@ -100,14 +102,33 @@ def test_create_pod_id(self):
@mock.patch("airflow.kubernetes.pod_generator.PodGenerator")
@mock.patch("airflow.executors.kubernetes_executor.KubeConfig")
def test_get_base_pod_from_template(self, mock_kubeconfig, mock_generator):
# Provide non-existent file path,
# so None will be passed to deserialize_model_dict().
pod_template_file_path = "/bar/biz"
get_base_pod_from_template(pod_template_file_path, None)
assert "deserialize_model_dict" == mock_generator.mock_calls[0][0]
assert pod_template_file_path == mock_generator.mock_calls[0][1][0]
assert mock_generator.mock_calls[0][1][0] is None

mock_kubeconfig.pod_template_file = "/foo/bar"
get_base_pod_from_template(None, mock_kubeconfig)
assert "deserialize_model_dict" == mock_generator.mock_calls[1][0]
assert "/foo/bar" == mock_generator.mock_calls[1][1][0]
assert mock_generator.mock_calls[1][1][0] is None

# Provide existent file path,
# so loaded YAML file content should be used to call deserialize_model_dict(), rather than None.
path = sys.path[0] + '/tests/kubernetes/pod.yaml'
with open(path) as stream:
expected_pod_dict = yaml.safe_load(stream)

pod_template_file_path = path
get_base_pod_from_template(pod_template_file_path, None)
assert "deserialize_model_dict" == mock_generator.mock_calls[2][0]
assert mock_generator.mock_calls[2][1][0] == expected_pod_dict

mock_kubeconfig.pod_template_file = path
get_base_pod_from_template(None, mock_kubeconfig)
assert "deserialize_model_dict" == mock_generator.mock_calls[3][0]
assert mock_generator.mock_calls[3][1][0] == expected_pod_dict

def test_make_safe_label_value(self):
for dag_id, task_id in self._cases():
Expand Down Expand Up @@ -228,8 +249,6 @@ def test_run_next_exception_requeue(
- 400 BadRequest is returned when your parameters are invalid e.g. asking for cpu=100ABC123.
"""
import sys

path = sys.path[0] + '/tests/kubernetes/pod_generator_base_with_secrets.yaml'

response = HTTPResponse(body='{"message": "any message"}', status=status)
Expand Down Expand Up @@ -283,8 +302,6 @@ def test_run_next_pod_reconciliation_error(self, mock_get_kube_client, mock_kube
"""
When construct_pod raises PodReconciliationError, we should fail the task.
"""
import sys

path = sys.path[0] + '/tests/kubernetes/pod_generator_base_with_secrets.yaml'

mock_kube_client = mock.patch('kubernetes.client.CoreV1Api', autospec=True)
Expand Down
34 changes: 10 additions & 24 deletions tests/kubernetes/test_pod_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,11 +712,20 @@ def test_reconcile_specs_init_containers(self):
res = PodGenerator.reconcile_specs(base_spec, client_spec)
assert res.init_containers == base_spec.init_containers + client_spec.init_containers

def test_deserialize_model_file(self):
def test_deserialize_model_file(self, caplog):
path = sys.path[0] + '/tests/kubernetes/pod.yaml'
result = PodGenerator.deserialize_model_file(path)
sanitized_res = self.k8s_client.sanitize_for_serialization(result)
assert sanitized_res == self.deserialize_result
assert len(caplog.records) == 0

def test_deserialize_non_existent_model_file(self, caplog):
path = sys.path[0] + '/tests/kubernetes/non_existent.yaml'
result = PodGenerator.deserialize_model_file(path)
sanitized_res = self.k8s_client.sanitize_for_serialization(result)
assert sanitized_res == {}
assert len(caplog.records) == 1
assert 'does not exist' in caplog.text

@parameterized.expand(
(
Expand Down Expand Up @@ -761,29 +770,6 @@ def test_pod_name_is_valid(self, pod_id, expected_starts_with):

assert name.rsplit("-", 1)[0] == expected_starts_with

def test_deserialize_model_string(self):
fixture = """
apiVersion: v1
kind: Pod
metadata:
name: memory-demo
namespace: mem-example
spec:
containers:
- name: memory-demo-ctr
image: ghcr.io/apache/airflow-stress:1.0.4-2021.07.04
resources:
limits:
memory: "200Mi"
requests:
memory: "100Mi"
command: ["stress"]
args: ["--vm", "1", "--vm-bytes", "150M", "--vm-hang", "1"]
"""
result = PodGenerator.deserialize_model_file(fixture)
sanitized_res = self.k8s_client.sanitize_for_serialization(result)
assert sanitized_res == self.deserialize_result

def test_validate_pod_generator(self):
with pytest.raises(AirflowConfigException):
PodGenerator(pod=k8s.V1Pod(), pod_template_file='k')
Expand Down

0 comments on commit 35deda4

Please sign in to comment.