Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
from airflow.configuration import conf
from airflow.providers.amazon.aws.executors.ecs.utils import (
CONFIG_GROUP_NAME,
ECS_LAUNCH_TYPE_EC2,
ECS_LAUNCH_TYPE_FARGATE,
AllEcsConfigKeys,
RunTaskKwargsConfigKeys,
camelize_dict_keys,
Expand All @@ -56,13 +58,15 @@ def _fetch_config_values() -> dict[str, str]:


def build_task_kwargs() -> dict:
all_config_keys = AllEcsConfigKeys()
# This will put some kwargs at the root of the dictionary that do NOT belong there. However,
# the code below expects them to be there and will rearrange them as necessary.
task_kwargs = _fetch_config_values()
task_kwargs.update(_fetch_templated_kwargs())

has_launch_type: bool = "launch_type" in task_kwargs
has_capacity_provider: bool = "capacity_provider_strategy" in task_kwargs
has_launch_type: bool = all_config_keys.LAUNCH_TYPE in task_kwargs
has_capacity_provider: bool = all_config_keys.CAPACITY_PROVIDER_STRATEGY in task_kwargs
is_launch_type_ec2: bool = task_kwargs.get(all_config_keys.LAUNCH_TYPE, None) == ECS_LAUNCH_TYPE_EC2

if has_capacity_provider and has_launch_type:
raise ValueError(
Expand All @@ -75,7 +79,12 @@ def build_task_kwargs() -> dict:
# the final fallback.
cluster = EcsHook().conn.describe_clusters(clusters=[task_kwargs["cluster"]])["clusters"][0]
if not cluster.get("defaultCapacityProviderStrategy"):
task_kwargs["launch_type"] = "FARGATE"
task_kwargs[all_config_keys.LAUNCH_TYPE] = ECS_LAUNCH_TYPE_FARGATE

# If you're using the EC2 launch type, you should not/can not provide the platform_version. In this
# case we'll drop it on the floor on behalf of the user, instead of throwing an exception.
if is_launch_type_ec2:
task_kwargs.pop(all_config_keys.PLATFORM_VERSION, None)

# There can only be 1 count of these containers
task_kwargs["count"] = 1 # type: ignore
Expand Down Expand Up @@ -105,7 +114,7 @@ def build_task_kwargs() -> dict:
"awsvpcConfiguration": {
"subnets": str(subnets).split(",") if subnets else None,
"securityGroups": str(security_groups).split(",") if security_groups else None,
"assignPublicIp": parse_assign_public_ip(assign_public_ip),
"assignPublicIp": parse_assign_public_ip(assign_public_ip, is_launch_type_ec2),
}
}
)
Expand Down
10 changes: 8 additions & 2 deletions airflow/providers/amazon/aws/executors/ecs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
ExecutorConfigFunctionType = Callable[[CommandType], dict]
ExecutorConfigType = Dict[str, Any]

ECS_LAUNCH_TYPE_EC2 = "EC2"
ECS_LAUNCH_TYPE_FARGATE = "FARGATE"

CONFIG_GROUP_NAME = "aws_ecs_executor"

CONFIG_DEFAULTS = {
Expand Down Expand Up @@ -247,9 +250,12 @@ def _recursive_flatten_dict(nested_dict):
return dict(items)


def parse_assign_public_ip(assign_public_ip):
def parse_assign_public_ip(assign_public_ip, is_launch_type_ec2=False):
"""Convert "assign_public_ip" from True/False to ENABLE/DISABLE."""
return "ENABLED" if assign_public_ip == "True" else "DISABLED"
# If the launch type is EC2, you cannot/should not provide the assignPublicIp parameter (which is
# specific to Fargate)
if not is_launch_type_ec2:
return "ENABLED" if assign_public_ip == "True" else "DISABLED"


def camelize_dict_keys(nested_dict) -> dict:
Expand Down