Skip to content
249 changes: 161 additions & 88 deletions airflow/providers/amazon/aws/operators/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
"""This module contains Amazon EKS operators."""
from __future__ import annotations

import logging
import warnings
from ast import literal_eval
from datetime import timedelta
from typing import TYPE_CHECKING, Any, List, Sequence, cast
from typing import TYPE_CHECKING, List, Sequence, cast

from botocore.exceptions import ClientError, WaiterError

Expand All @@ -31,6 +32,7 @@
EksCreateFargateProfileTrigger,
EksDeleteFargateProfileTrigger,
)
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait

try:
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
Expand Down Expand Up @@ -59,6 +61,75 @@
FARGATE_FULL_NAME = "AWS Fargate profiles"


def _create_compute(
compute: str | None,
cluster_name: str,
aws_conn_id: str,
region: str | None,
waiter_delay: int,
waiter_max_attempts: int,
wait_for_completion: bool = False,
nodegroup_name: str | None = None,
nodegroup_role_arn: str | None = None,
create_nodegroup_kwargs: dict | None = None,
fargate_profile_name: str | None = None,
fargate_pod_execution_role_arn: str | None = None,
fargate_selectors: list | None = None,
create_fargate_profile_kwargs: dict | None = None,
subnets: list[str] | None = None,
):
log = logging.getLogger(__name__)
eks_hook = EksHook(aws_conn_id=aws_conn_id, region_name=region)
if compute == "nodegroup" and nodegroup_name:

# this is to satisfy mypy
subnets = subnets or []
create_nodegroup_kwargs = create_nodegroup_kwargs or {}

eks_hook.create_nodegroup(
clusterName=cluster_name,
nodegroupName=nodegroup_name,
subnets=subnets,
nodeRole=nodegroup_role_arn,
**create_nodegroup_kwargs,
)
if wait_for_completion:
log.info("Waiting for nodegroup to provision. This will take some time.")
wait(
waiter=eks_hook.conn.get_waiter("nodegroup_active"),
waiter_delay=waiter_delay,
max_attempts=waiter_max_attempts,
args={"clusterName": cluster_name, "nodegroupName": nodegroup_name},
failure_message="Nodegroup creation failed",
status_message="Nodegroup status is",
status_args=["nodegroup.status"],
)
elif compute == "fargate" and fargate_profile_name:

# this is to satisfy mypy
create_fargate_profile_kwargs = create_fargate_profile_kwargs or {}
fargate_selectors = fargate_selectors or []

eks_hook.create_fargate_profile(
clusterName=cluster_name,
fargateProfileName=fargate_profile_name,
podExecutionRoleArn=fargate_pod_execution_role_arn,
selectors=fargate_selectors,
**create_fargate_profile_kwargs,
)
if wait_for_completion:
log.info("Waiting for Fargate profile to provision. This will take some time.")
wait(
waiter=eks_hook.conn.get_waiter("fargate_profile_active"),
waiter_delay=waiter_delay,
max_attempts=waiter_max_attempts,
args={"clusterName": cluster_name, "fargateProfileName": fargate_profile_name},
failure_message="Fargate profile creation failed",
status_message="Fargate profile status is",
status_args=["fargateProfile.status"],
)


class EksCreateClusterOperator(BaseOperator):
"""
Creates an Amazon EKS Cluster control plane.
Expand Down Expand Up @@ -112,6 +183,8 @@ class EksCreateClusterOperator(BaseOperator):
:param fargate_selectors: The selectors to match for pods to use this AWS Fargate profile. (templated)
:param create_fargate_profile_kwargs: Optional parameters to pass to the CreateFargateProfile API
(templated)
:param waiter_delay: Time (in seconds) to wait between two consecutive calls to check cluster status
:param waiter_max_attempts: The maximum number of attempts to check the status of the cluster.

"""

Expand All @@ -137,7 +210,7 @@ def __init__(
self,
cluster_name: str,
cluster_role_arn: str,
resources_vpc_config: dict[str, Any],
resources_vpc_config: dict,
compute: str | None = DEFAULT_COMPUTE_TYPE,
create_cluster_kwargs: dict | None = None,
nodegroup_name: str = DEFAULT_NODEGROUP_NAME,
Expand All @@ -150,24 +223,30 @@ def __init__(
wait_for_completion: bool = False,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
waiter_delay: int = 30,
waiter_max_attempts: int = 40,
**kwargs,
) -> None:
self.compute = compute
self.cluster_name = cluster_name
self.cluster_role_arn = cluster_role_arn
self.resources_vpc_config = resources_vpc_config
self.create_cluster_kwargs = create_cluster_kwargs or {}
self.nodegroup_name = nodegroup_name
self.nodegroup_role_arn = nodegroup_role_arn
self.create_nodegroup_kwargs = create_nodegroup_kwargs or {}
self.fargate_profile_name = fargate_profile_name
self.fargate_pod_execution_role_arn = fargate_pod_execution_role_arn
self.fargate_selectors = fargate_selectors or [{"namespace": DEFAULT_NAMESPACE_NAME}]
self.create_fargate_profile_kwargs = create_fargate_profile_kwargs or {}
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.aws_conn_id = aws_conn_id
self.region = region
super().__init__(**kwargs)
self.nodegroup_name = nodegroup_name
self.create_nodegroup_kwargs = create_nodegroup_kwargs or {}
self.fargate_selectors = fargate_selectors or [{"namespace": DEFAULT_NAMESPACE_NAME}]
self.fargate_profile_name = fargate_profile_name
super().__init__(
**kwargs,
)

def execute(self, context: Context):
if self.compute:
Expand All @@ -183,13 +262,8 @@ def execute(self, context: Context):
compute=FARGATE_FULL_NAME, requirement="fargate_pod_execution_role_arn"
)
)

eks_hook = EksHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
)

eks_hook.create_cluster(
self.eks_hook = EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
self.eks_hook.create_cluster(
name=self.cluster_name,
roleArn=self.cluster_role_arn,
resourcesVpcConfig=self.resources_vpc_config,
Expand All @@ -202,44 +276,38 @@ def execute(self, context: Context):
return None

self.log.info("Waiting for EKS Cluster to provision. This will take some time.")
client = eks_hook.conn
client = self.eks_hook.conn

try:
client.get_waiter("cluster_active").wait(name=self.cluster_name)
client.get_waiter("cluster_active").wait(
name=self.cluster_name,
WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
)
except (ClientError, WaiterError) as e:
self.log.error("Cluster failed to start and will be torn down.\n %s", e)
eks_hook.delete_cluster(name=self.cluster_name)
client.get_waiter("cluster_deleted").wait(name=self.cluster_name)
raise

if self.compute == "nodegroup":
eks_hook.create_nodegroup(
clusterName=self.cluster_name,
nodegroupName=self.nodegroup_name,
subnets=cast(List[str], self.resources_vpc_config.get("subnetIds")),
nodeRole=self.nodegroup_role_arn,
**self.create_nodegroup_kwargs,
)
if self.wait_for_completion:
self.log.info("Waiting for nodegroup to provision. This will take some time.")
client.get_waiter("nodegroup_active").wait(
clusterName=self.cluster_name,
nodegroupName=self.nodegroup_name,
)
elif self.compute == "fargate":
eks_hook.create_fargate_profile(
clusterName=self.cluster_name,
fargateProfileName=self.fargate_profile_name,
podExecutionRoleArn=self.fargate_pod_execution_role_arn,
selectors=self.fargate_selectors,
**self.create_fargate_profile_kwargs,
self.eks_hook.delete_cluster(name=self.cluster_name)
client.get_waiter("cluster_deleted").wait(
name=self.cluster_name,
WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
)
if self.wait_for_completion:
self.log.info("Waiting for Fargate profile to provision. This will take some time.")
client.get_waiter("fargate_profile_active").wait(
clusterName=self.cluster_name,
fargateProfileName=self.fargate_profile_name,
)
raise
_create_compute(
compute=self.compute,
cluster_name=self.cluster_name,
aws_conn_id=self.aws_conn_id,
region=self.region,
wait_for_completion=self.wait_for_completion,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
nodegroup_name=self.nodegroup_name,
nodegroup_role_arn=self.nodegroup_role_arn,
create_nodegroup_kwargs=self.create_nodegroup_kwargs,
fargate_profile_name=self.fargate_profile_name,
fargate_pod_execution_role_arn=self.fargate_pod_execution_role_arn,
fargate_selectors=self.fargate_selectors,
create_fargate_profile_kwargs=self.create_fargate_profile_kwargs,
subnets=cast(List[str], self.resources_vpc_config.get("subnetIds")),
)


class EksCreateNodegroupOperator(BaseOperator):
Expand All @@ -265,6 +333,8 @@ class EksCreateNodegroupOperator(BaseOperator):
maintained on each worker node).
:param region: Which AWS region the connection should use. (templated)
If this is None or empty then the default boto3 behaviour is used.
:param waiter_delay: Time (in seconds) to wait between two consecutive calls to check nodegroup status
:param waiter_max_attempts: The maximum number of attempts to check the status of the nodegroup.

"""

Expand All @@ -289,19 +359,28 @@ def __init__(
wait_for_completion: bool = False,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
waiter_delay: int = 30,
waiter_max_attempts: int = 80,
**kwargs,
) -> None:
self.nodegroup_subnets = nodegroup_subnets
self.compute = "nodegroup"
self.cluster_name = cluster_name
self.nodegroup_role_arn = nodegroup_role_arn
self.nodegroup_name = nodegroup_name
self.create_nodegroup_kwargs = create_nodegroup_kwargs or {}
self.wait_for_completion = wait_for_completion
self.aws_conn_id = aws_conn_id
self.region = region
self.nodegroup_subnets = nodegroup_subnets
super().__init__(**kwargs)
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts

super().__init__(
**kwargs,
)

def execute(self, context: Context):
self.log.info(self.task_id)
if isinstance(self.nodegroup_subnets, str):
nodegroup_subnets_list: list[str] = []
if self.nodegroup_subnets != "":
Expand All @@ -314,25 +393,20 @@ def execute(self, context: Context):
self.nodegroup_subnets,
)
self.nodegroup_subnets = nodegroup_subnets_list

eks_hook = EksHook(
_create_compute(
compute=self.compute,
cluster_name=self.cluster_name,
aws_conn_id=self.aws_conn_id,
region_name=self.region,
)
eks_hook.create_nodegroup(
clusterName=self.cluster_name,
nodegroupName=self.nodegroup_name,
region=self.region,
wait_for_completion=self.wait_for_completion,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
nodegroup_name=self.nodegroup_name,
nodegroup_role_arn=self.nodegroup_role_arn,
create_nodegroup_kwargs=self.create_nodegroup_kwargs,
subnets=self.nodegroup_subnets,
nodeRole=self.nodegroup_role_arn,
**self.create_nodegroup_kwargs,
)

if self.wait_for_completion:
self.log.info("Waiting for nodegroup to provision. This will take some time.")
eks_hook.conn.get_waiter("nodegroup_active").wait(
clusterName=self.cluster_name, nodegroupName=self.nodegroup_name
)


class EksCreateFargateProfileOperator(BaseOperator):
"""
Expand Down Expand Up @@ -392,52 +466,50 @@ def __init__(
**kwargs,
) -> None:
self.cluster_name = cluster_name
self.pod_execution_role_arn = pod_execution_role_arn
self.selectors = selectors
self.pod_execution_role_arn = pod_execution_role_arn
self.fargate_profile_name = fargate_profile_name
self.create_fargate_profile_kwargs = create_fargate_profile_kwargs or {}
self.wait_for_completion = wait_for_completion
self.wait_for_completion = False if deferrable else wait_for_completion
self.aws_conn_id = aws_conn_id
self.region = region
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable
super().__init__(**kwargs)
self.compute = "fargate"
super().__init__(
**kwargs,
)

def execute(self, context: Context):
eks_hook = EksHook(
_create_compute(
compute=self.compute,
cluster_name=self.cluster_name,
aws_conn_id=self.aws_conn_id,
region_name=self.region,
)

eks_hook.create_fargate_profile(
clusterName=self.cluster_name,
fargateProfileName=self.fargate_profile_name,
podExecutionRoleArn=self.pod_execution_role_arn,
selectors=self.selectors,
**self.create_fargate_profile_kwargs,
region=self.region,
wait_for_completion=self.wait_for_completion,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
fargate_profile_name=self.fargate_profile_name,
fargate_pod_execution_role_arn=self.pod_execution_role_arn,
fargate_selectors=self.selectors,
create_fargate_profile_kwargs=self.create_fargate_profile_kwargs,
)
if self.deferrable:
self.defer(
trigger=EksCreateFargateProfileTrigger(
cluster_name=self.cluster_name,
fargate_profile_name=self.fargate_profile_name,
aws_conn_id=self.aws_conn_id,
poll_interval=self.waiter_delay,
max_attempts=self.waiter_max_attempts,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
region=self.region,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
# 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
timeout=timedelta(seconds=(self.waiter_max_attempts * self.waiter_delay + 60)),
)
elif self.wait_for_completion:
self.log.info("Waiting for Fargate profile to provision. This will take some time.")
eks_hook.conn.get_waiter("fargate_profile_active").wait(
clusterName=self.cluster_name,
fargateProfileName=self.fargate_profile_name,
WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
)

def execute_complete(self, context, event=None):
if event["status"] != "success":
Expand Down Expand Up @@ -677,8 +749,9 @@ def execute(self, context: Context):
cluster_name=self.cluster_name,
fargate_profile_name=self.fargate_profile_name,
aws_conn_id=self.aws_conn_id,
poll_interval=self.waiter_delay,
max_attempts=self.waiter_max_attempts,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
region=self.region,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
Expand Down
Loading