Skip to content

Commit 60bc25e

Browse files
committed
feat(enums): Add AliasedStrEnum helper class and tests
1 parent 5d62a2e commit 60bc25e

File tree

2 files changed

+237
-0
lines changed

2 files changed

+237
-0
lines changed

src/dda/utils/enums.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# SPDX-FileCopyrightText: 2025-present Datadog, Inc. <dev@datadoghq.com>
2+
#
3+
# SPDX-License-Identifier: MIT
4+
5+
from collections.abc import Iterator
6+
from enum import Enum, EnumMeta, StrEnum
7+
from typing import Self
8+
9+
10+
# Since __iter__ is defined at metaclass level for Enums, we cannot override it in AliasedStrEnum
11+
# So we need to create a new metaclass that overrides __iter__.
12+
class AliasedStrEnumMeta(EnumMeta):
13+
"""Metaclass for AliasedStrEnum that handles iteration correctly."""
14+
15+
def __iter__(cls) -> Iterator[Enum]: # type: ignore[override]
16+
"""Override default StrEnum iteration - needed because of __new__ override."""
17+
return iter(x for x in cls._member_map_.values())
18+
19+
20+
class AliasedStrEnum(StrEnum, metaclass=AliasedStrEnumMeta):
21+
"""
22+
A string enumeration that supports multiple alias values for each member.
23+
The first value is considered the canonical value.
24+
25+
When instantiating or comparing, any of the alias values will be considered equal to the canonical value.
26+
Example:
27+
class OS(AliasedStrEnum):
28+
LINUX = "linux"
29+
WINDOWS = ("windows", "nt", "win")
30+
31+
OS.WINDOWS == "windows" # True
32+
OS.WINDOWS == "nt" # True
33+
OS.WINDOWS == "linux" # False
34+
35+
OS("win") is OS.WINDOWS # True
36+
OS("linux") is OS.LINUX # True
37+
OS("bsd") # Raises ValueError
38+
39+
OS["WINDOWS"] # OS.WINDOWS
40+
OS["win"] # Raises KeyError - use OS("win") instead
41+
OS.WINDOWS.name # "WINDOWS"
42+
43+
str(OS.WINDOWS) # "windows"
44+
OS.WINDOWS.value # "windows"
45+
OS.WINDOWS.values # {"windows", "nt", "win"}
46+
OS.WINDOWS.aliases # {"nt", "win"}
47+
"""
48+
49+
# This method is called when registering new enum members, not when instantiating them.
50+
def __new__(cls, value: str, *alt_values: str) -> Self:
51+
# Create the StrEnum member instance
52+
obj = str.__new__(cls, value)
53+
obj._value_ = value
54+
obj._alt_values_ = set(alt_values) # type: ignore[attr-defined]
55+
56+
# Register in the reverse mapping for by-value lookup (e.g. Enum(value))
57+
cls._value2member_map_[value] = obj
58+
for alias in alt_values:
59+
cls._value2member_map_[alias] = obj
60+
return obj
61+
62+
@property
63+
def aliases(self) -> set[str]:
64+
"""Returns the set of non-canonical alias values for this enum member."""
65+
return self._alt_values_ # type: ignore[attr-defined]
66+
67+
@property
68+
def values(self) -> set[str]:
69+
"""Returns the set of all values (canonical + aliases) for this enum member."""
70+
return {self.value} | self._alt_values_ # type: ignore[attr-defined]
71+
72+
def __str__(self) -> str:
73+
return self.value
74+
75+
def __repr__(self) -> str:
76+
aliases_str = f" {tuple(sorted(self._alt_values_))!r}" if self._alt_values_ else "" # type: ignore[attr-defined]
77+
return f"<{self.__class__.__name__}.{self.name}: {self.value!r}{aliases_str}>"
78+
79+
def __eq__(self, value: object) -> bool:
80+
"""Override default StrEnum equality to consider equality with any of the alias values."""
81+
if isinstance(value, str):
82+
return value in self.values
83+
return super().__eq__(value)
84+
85+
def __contains__(self, item: object) -> bool:
86+
"""Override default StrEnum containment to consider containment with any of the alias values."""
87+
if isinstance(item, str):
88+
return item in self.values
89+
return False
90+
91+
def __hash__(self) -> int:
92+
"""Override default StrEnum hash to consider hash of any of the alias values."""
93+
return hash(self.value)

tests/utils/test_enums.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# SPDX-FileCopyrightText: 2024-present Datadog, Inc. <dev@datadoghq.com>
2+
#
3+
# SPDX-License-Identifier: MIT
4+
from __future__ import annotations
5+
6+
import pytest
7+
8+
from dda.utils.enums import AliasedStrEnum
9+
10+
11+
class ExampleEnum(AliasedStrEnum):
12+
"""Test enum for testing AliasedStrEnum functionality."""
13+
14+
LINUX = "linux"
15+
WINDOWS = ("windows", "nt", "win")
16+
MACOS = ("macos", "darwin", "osx")
17+
18+
19+
class TestAliasedStrEnum:
20+
def test_canonical_value_access(self):
21+
"""Test that canonical values work correctly."""
22+
assert ExampleEnum.LINUX == "linux"
23+
assert ExampleEnum.WINDOWS == "windows"
24+
assert ExampleEnum.MACOS == "macos"
25+
26+
def test_alias_equality(self):
27+
"""Test that alias values are equal to the enum member."""
28+
assert ExampleEnum.WINDOWS == "windows"
29+
assert ExampleEnum.WINDOWS == "nt"
30+
assert ExampleEnum.WINDOWS == "win"
31+
32+
assert ExampleEnum.MACOS == "macos"
33+
assert ExampleEnum.MACOS == "darwin"
34+
assert ExampleEnum.MACOS == "osx"
35+
36+
def test_alias_inequality(self):
37+
"""Test that non-alias values are not equal."""
38+
assert ExampleEnum.WINDOWS != "linux"
39+
assert ExampleEnum.LINUX != "windows"
40+
assert ExampleEnum.MACOS != "linux"
41+
42+
def test_instantiation_by_canonical_value(self):
43+
"""Test instantiation using canonical values."""
44+
assert ExampleEnum("linux") is ExampleEnum.LINUX
45+
assert ExampleEnum("windows") is ExampleEnum.WINDOWS
46+
assert ExampleEnum("macos") is ExampleEnum.MACOS
47+
48+
def test_instantiation_by_alias(self):
49+
"""Test instantiation using alias values."""
50+
assert ExampleEnum("nt") is ExampleEnum.WINDOWS
51+
assert ExampleEnum("win") is ExampleEnum.WINDOWS
52+
assert ExampleEnum("darwin") is ExampleEnum.MACOS
53+
assert ExampleEnum("osx") is ExampleEnum.MACOS
54+
55+
def test_instantiation_invalid_value(self):
56+
"""Test that invalid values raise ValueError."""
57+
with pytest.raises(match="'bsd' is not a valid TestOS"):
58+
ExampleEnum("bsd")
59+
with pytest.raises(match="'invalid' is not a valid TestOS"):
60+
ExampleEnum("invalid")
61+
62+
def test_name_access(self):
63+
"""Test that name property returns the enum member name."""
64+
assert ExampleEnum.LINUX.name == "LINUX"
65+
assert ExampleEnum.WINDOWS.name == "WINDOWS"
66+
assert ExampleEnum.MACOS.name == "MACOS"
67+
68+
def test_value_property(self):
69+
"""Test that value property returns the canonical value."""
70+
assert ExampleEnum.LINUX.value == "linux"
71+
assert ExampleEnum.WINDOWS.value == "windows"
72+
assert ExampleEnum.MACOS.value == "macos"
73+
74+
def test_aliases_property(self):
75+
"""Test that aliases property returns the set of non-canonical values."""
76+
assert ExampleEnum.LINUX.aliases == set()
77+
assert ExampleEnum.WINDOWS.aliases == {"nt", "win"}
78+
assert ExampleEnum.MACOS.aliases == {"darwin", "osx"}
79+
80+
def test_values_property(self):
81+
"""Test that values property returns all values (canonical + aliases)."""
82+
assert ExampleEnum.LINUX.values == {"linux"}
83+
assert ExampleEnum.WINDOWS.values == {"windows", "nt", "win"}
84+
assert ExampleEnum.MACOS.values == {"macos", "darwin", "osx"}
85+
86+
def test_str_representation(self):
87+
"""Test string representation returns canonical value."""
88+
assert str(ExampleEnum.LINUX) == "linux"
89+
assert str(ExampleEnum.WINDOWS) == "windows"
90+
assert str(ExampleEnum.MACOS) == "macos"
91+
92+
def test_repr_representation(self):
93+
"""Test repr representation includes aliases."""
94+
assert repr(ExampleEnum.LINUX) == "<TestOS.LINUX: 'linux'>"
95+
96+
# For members with aliases, check that aliases are included
97+
assert repr(ExampleEnum.WINDOWS) == "<TestOS.WINDOWS: 'windows' ('nt', 'win')>"
98+
assert repr(ExampleEnum.MACOS) == "<TestOS.MACOS: 'macos' ('darwin', 'osx')>"
99+
100+
def test_containment(self):
101+
"""Test containment operator with alias values."""
102+
assert "linux" in ExampleEnum.LINUX
103+
assert "windows" in ExampleEnum.WINDOWS
104+
assert "nt" in ExampleEnum.WINDOWS
105+
assert "win" in ExampleEnum.WINDOWS
106+
assert "darwin" in ExampleEnum.MACOS
107+
assert "osx" in ExampleEnum.MACOS
108+
109+
assert "bsd" not in ExampleEnum.LINUX
110+
assert "linux" not in ExampleEnum.WINDOWS
111+
112+
def test_hash_consistency(self):
113+
"""Test that hash is consistent for the same enum member."""
114+
assert hash(ExampleEnum.LINUX) == hash(ExampleEnum.LINUX)
115+
assert hash(ExampleEnum.WINDOWS) == hash(ExampleEnum.WINDOWS)
116+
117+
# Hash should be based on canonical value
118+
assert hash(ExampleEnum.WINDOWS) == hash("windows")
119+
120+
def test_member_access_by_name(self):
121+
"""Test accessing members by name using bracket notation."""
122+
assert ExampleEnum["LINUX"] is ExampleEnum.LINUX
123+
assert ExampleEnum["WINDOWS"] is ExampleEnum.WINDOWS
124+
assert ExampleEnum["MACOS"] is ExampleEnum.MACOS
125+
126+
def test_member_access_by_alias_fails(self):
127+
"""Test that accessing by alias value using bracket notation fails."""
128+
with pytest.raises(KeyError):
129+
ExampleEnum["nt"]
130+
with pytest.raises(KeyError):
131+
ExampleEnum["darwin"]
132+
133+
def test_iteration(self):
134+
"""Test that enum can be iterated over."""
135+
members = list(ExampleEnum)
136+
assert len(members) == 3
137+
assert ExampleEnum.LINUX in members
138+
assert ExampleEnum.WINDOWS in members
139+
assert ExampleEnum.MACOS in members
140+
141+
members2 = []
142+
for member in ExampleEnum:
143+
members2.append(member) # noqa: PERF402
144+
assert members == members2

0 commit comments

Comments
 (0)