Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 1 addition & 18 deletions airflow/providers/amazon/aws/hooks/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import time
from collections import deque
from datetime import datetime, timedelta
from enum import Enum
from logging import Logger
from threading import Event, Thread
from typing import Generator
Expand All @@ -31,6 +30,7 @@
from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.utils import _StringCompareEnum
from airflow.typing_compat import Protocol, runtime_checkable


Expand All @@ -55,23 +55,6 @@ def should_retry_eni(exception: Exception):
return False


class _StringCompareEnum(Enum):
"""
Enum which can be compared with regular `str` and subclasses.

This class avoids multiple inheritance such as AwesomeEnum(str, Enum)
which does not work well with templated_fields and Jinja templates.
"""

def __eq__(self, other):
if isinstance(other, str):
return self.value == other
return super().__eq__(other)

def __hash__(self):
return super().__hash__() # Need to set because we redefine __eq__


class EcsClusterStates(_StringCompareEnum):
"""Contains the possible State values of an ECS Cluster."""

Expand Down
18 changes: 18 additions & 0 deletions airflow/providers/amazon/aws/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import re
from datetime import datetime
from enum import Enum

from airflow.version import version

Expand All @@ -44,3 +45,20 @@ def datetime_to_epoch_us(date_time: datetime) -> int:
def get_airflow_version() -> tuple[int, ...]:
val = re.sub(r"(\d+\.\d+\.\d+).*", lambda x: x.group(1), version)
return tuple(int(x) for x in val.split("."))


class _StringCompareEnum(Enum):
"""
An Enum class which can be compared with regular `str` and subclasses.

This class avoids multiple inheritance such as AwesomeEnum(str, Enum)
which does not work well with templated_fields and Jinja templates.
"""

def __eq__(self, other):
if isinstance(other, str):
return self.value == other
return super().__eq__(other)

def __hash__(self):
return super().__hash__() # Need to set because we redefine __eq__
29 changes: 19 additions & 10 deletions tests/providers/amazon/aws/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytz

from airflow.providers.amazon.aws.utils import (
_StringCompareEnum,
datetime_to_epoch,
datetime_to_epoch_ms,
datetime_to_epoch_us,
Expand All @@ -32,16 +33,19 @@
EPOCH = 946_684_800


class TestUtils:
def test_trim_none_values(self):
input_object = {
"test": "test",
"empty": None,
}
expected_output_object = {
"test": "test",
}
assert trim_none_values(input_object) == expected_output_object
class EnumTest(_StringCompareEnum):
FOO = "FOO"


def test_trim_none_values():
input_object = {
"test": "test",
"empty": None,
}
expected_output_object = {
"test": "test",
}
assert trim_none_values(input_object) == expected_output_object


def test_datetime_to_epoch():
Expand All @@ -58,3 +62,8 @@ def test_datetime_to_epoch_us():

def test_get_airflow_version():
assert len(get_airflow_version()) == 3


def test_str_enum():
assert EnumTest.FOO == "FOO"
assert EnumTest.FOO.value == "FOO"