Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 60 additions & 30 deletions airflow/providers/google/cloud/operators/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Any, Sequence

from google.api_core.exceptions import AlreadyExists
from google.cloud.container_v1.types import Cluster
Expand Down Expand Up @@ -268,41 +268,71 @@ def __init__(
self.impersonation_chain = impersonation_chain
self.poll_interval = poll_interval
self.deferrable = deferrable
self._check_input()
self._validate_input()

self._hook: GKEHook | None = None

def _check_input(self) -> None:
if (
not all([self.project_id, self.location, self.body])
or (isinstance(self.body, dict) and "name" not in self.body)
or (
isinstance(self.body, dict)
and ("initial_node_count" not in self.body and "node_pools" not in self.body)
)
or (not (isinstance(self.body, dict)) and not (getattr(self.body, "name", None)))
or (
not (isinstance(self.body, dict))
and (
not (getattr(self.body, "initial_node_count", None))
and not (getattr(self.body, "node_pools", None))
def _validate_input(self) -> None:
"""Primary validation of the input body."""
self._alert_deprecated_body_fields()

error_messages: list[str] = []
if not self._body_field("name"):
error_messages.append("Field body['name'] is missing or incorrect")

if self._body_field("initial_node_count"):
if self._body_field("node_pools"):
error_messages.append(
"Do not use filed body['initial_node_count'] and body['node_pools'] at the same time."
)
)
):
self.log.error(
"One of (project_id, location, body, body['name'], "
"body['initial_node_count']), body['node_pools'] is missing or incorrect"
)
raise AirflowException("Operator has incorrect or missing input.")
elif (
isinstance(self.body, dict) and ("initial_node_count" in self.body and "node_pools" in self.body)
) or (
not (isinstance(self.body, dict))
and (getattr(self.body, "initial_node_count", None) and getattr(self.body, "node_pools", None))
):
self.log.error("Only one of body['initial_node_count']) and body['node_pools'] may be specified")

if self._body_field("node_config"):
if self._body_field("node_pools"):
error_messages.append(
"Do not use filed body['node_config'] and body['node_pools'] at the same time."
)

if self._body_field("node_pools"):
if any([self._body_field("node_config"), self._body_field("initial_node_count")]):
error_messages.append(
"The field body['node_pools'] should not be set if "
"body['node_config'] or body['initial_code_count'] are specified."
)

if not any([self._body_field("node_config"), self._body_field("initial_node_count")]):
if not self._body_field("node_pools"):
error_messages.append(
"Field body['node_pools'] is required if none of fields "
"body['initial_node_count'] or body['node_pools'] are specified."
)

for message in error_messages:
self.log.error(message)

if error_messages:
raise AirflowException("Operator has incorrect or missing input.")

def _body_field(self, field_name: str, default_value: Any = None) -> Any:
"""Extracts the value of the given field name."""
if isinstance(self.body, dict):
return self.body.get(field_name, default_value)
else:
return getattr(self.body, field_name, default_value)

def _alert_deprecated_body_fields(self) -> None:
"""Generates warning messages if deprecated fields were used in the body."""
deprecated_body_fields_with_replacement = [
("initial_node_count", "node_pool.initial_node_count"),
("node_config", "node_pool.config"),
("zone", "location"),
("instance_group_urls", "node_pools.instance_group_urls"),
]
for deprecated_field, replacement in deprecated_body_fields_with_replacement:
if self._body_field(deprecated_field):
warnings.warn(
f"The body field '{deprecated_field}' is deprecated. Use '{replacement}' instead."
)

def execute(self, context: Context) -> str:
hook = self._get_hook()
try:
Expand Down