|
23 | 23 | from datetime import datetime, timezone |
24 | 24 | from pathlib import Path |
25 | 25 | from subprocess import check_call, check_output |
| 26 | +from typing import Literal |
26 | 27 |
|
27 | 28 | import pytest |
28 | 29 | import re2 |
@@ -62,7 +63,7 @@ def base_tests_setup(self, request): |
62 | 63 | self.session = self._get_session_with_retries() |
63 | 64 |
|
64 | 65 | # Ensure the api-server deployment is healthy at kubernetes level before calling the any API |
65 | | - self.ensure_deployment_health("airflow-webserver") |
| 66 | + self.ensure_resource_health("airflow-webserver") |
66 | 67 | try: |
67 | 68 | self._ensure_airflow_webserver_is_healthy() |
68 | 69 | yield |
@@ -195,12 +196,25 @@ def monitor_task(self, host, dag_run_id, dag_id, task_id, expected_final_state, |
195 | 196 | assert state == expected_final_state |
196 | 197 |
|
197 | 198 | @staticmethod |
198 | | - def ensure_deployment_health(deployment_name: str, namespace: str = "airflow"): |
199 | | - """Watch the deployment until it is healthy.""" |
200 | | - deployment_rollout_status = check_output( |
201 | | - ["kubectl", "rollout", "status", "deployment", deployment_name, "-n", namespace, "--watch"] |
| 199 | + def ensure_resource_health( |
| 200 | + resource_name: str, |
| 201 | + namespace: str = "airflow", |
| 202 | + resource_type: Literal["deployment", "statefulset"] = "deployment", |
| 203 | + ): |
| 204 | + """Watch the resource until it is healthy. |
| 205 | +
|
| 206 | + Args: |
| 207 | + resource_name (str): Name of the resource to check. |
| 208 | + resource_type (str): Type of the resource (e.g., deployment, statefulset). |
| 209 | + namespace (str): Kubernetes namespace where the resource is located. |
| 210 | + """ |
| 211 | + rollout_status = check_output( |
| 212 | + ["kubectl", "rollout", "status", f"{resource_type}/{resource_name}", "-n", namespace, "--watch"], |
202 | 213 | ).decode() |
203 | | - assert "successfully rolled out" in deployment_rollout_status |
| 214 | + if resource_type == "deployment": |
| 215 | + assert "successfully rolled out" in rollout_status |
| 216 | + else: |
| 217 | + assert "roll out complete" in rollout_status |
204 | 218 |
|
205 | 219 | def ensure_dag_expected_state(self, host, execution_date, dag_id, expected_final_state, timeout): |
206 | 220 | tries = 0 |
|
0 commit comments