Skip to content

Commit 022d730

Browse files
committed
hdl._ast: add Format.Enum, Format.Struct, Format.Array.
1 parent c59447c commit 022d730

File tree

2 files changed

+206
-21
lines changed

2 files changed

+206
-21
lines changed

amaranth/hdl/_ast.py

Lines changed: 115 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2559,8 +2559,27 @@ def __repr__(self):
25592559
return "(initial)"
25602560

25612561

2562+
class _FormatLike:
2563+
def _as_format(self) -> "Format":
2564+
raise NotImplementedError # :nocov:
2565+
2566+
def __add__(self, other):
2567+
if not isinstance(other, _FormatLike):
2568+
return NotImplemented
2569+
return Format._from_chunks(self._as_format()._chunks + other._as_format()._chunks)
2570+
2571+
def __format__(self, format_desc):
2572+
"""Forbidden formatting.
2573+
2574+
``Format`` objects cannot be directly formatted for the same reason as the ``Value``s
2575+
they contain.
2576+
"""
2577+
raise TypeError(f"Format object {self!r} cannot be converted to string. Use `repr` "
2578+
f"to print the AST, or pass it to the `Print` statement.")
2579+
2580+
25622581
@final
2563-
class Format:
2582+
class Format(_FormatLike):
25642583
def __init__(self, format, *args, **kwargs):
25652584
fmt = string.Formatter()
25662585
chunks = []
@@ -2615,17 +2634,17 @@ def subformat(sub_string):
26152634
shape = obj.shape()
26162635
if isinstance(shape, ShapeCastable):
26172636
fmt = shape.format(obj, format_spec)
2618-
if not isinstance(fmt, Format):
2637+
if not isinstance(fmt, _FormatLike):
26192638
raise TypeError(f"`ShapeCastable.format` must return a 'Format' instance, not {fmt!r}")
2620-
chunks += fmt._chunks
2639+
chunks += fmt._as_format()._chunks
26212640
else:
26222641
obj = Value.cast(obj)
26232642
self._parse_format_spec(format_spec, obj.shape())
26242643
chunks.append((obj, format_spec))
2625-
elif isinstance(obj, Format):
2644+
elif isinstance(obj, _FormatLike):
26262645
if format_spec != "":
26272646
raise ValueError(f"Format specifiers ({format_spec!r}) cannot be used for 'Format' objects")
2628-
chunks += obj._chunks
2647+
chunks += obj._as_format()._chunks
26292648
else:
26302649
chunks.append(fmt.format_field(obj, format_spec))
26312650

@@ -2638,6 +2657,9 @@ def subformat(sub_string):
26382657

26392658
self._chunks = self._clean_chunks(chunks)
26402659

2660+
def _as_format(self):
2661+
return self
2662+
26412663
@classmethod
26422664
def _from_chunks(cls, chunks):
26432665
res = object.__new__(cls)
@@ -2671,25 +2693,11 @@ def _to_format_string(self):
26712693
format_string.append("{}")
26722694
return ("".join(format_string), tuple(args))
26732695

2674-
def __add__(self, other):
2675-
if not isinstance(other, Format):
2676-
return NotImplemented
2677-
return Format._from_chunks(self._chunks + other._chunks)
2678-
26792696
def __repr__(self):
26802697
format_string, args = self._to_format_string()
26812698
args = "".join(f" {arg!r}" for arg in args)
26822699
return f"(format {format_string!r}{args})"
26832700

2684-
def __format__(self, format_desc):
2685-
"""Forbidden formatting.
2686-
2687-
``Format`` objects cannot be directly formatted for the same reason as the ``Value``s
2688-
they contain.
2689-
"""
2690-
raise TypeError(f"Format object {self!r} cannot be converted to string. Use `repr` "
2691-
f"to print the AST, or pass it to the `Print` statement.")
2692-
26932701
_FORMAT_SPEC_PATTERN = re.compile(r"""
26942702
(?:
26952703
(?P<fill>.)?
@@ -2760,6 +2768,90 @@ def _rhs_signals(self):
27602768
return res
27612769

27622770

2771+
class Enum(_FormatLike):
2772+
def __init__(self, value, /, variants):
2773+
self._value = Value.cast(value)
2774+
if isinstance(variants, EnumMeta):
2775+
self._variants = {member.value: member.name for member in variants}
2776+
else:
2777+
self._variants = dict(variants)
2778+
for val, name in self._variants.items():
2779+
if not isinstance(val, int):
2780+
raise TypeError(f"Variant values must be integers, not {val!r}")
2781+
if not isinstance(name, str):
2782+
raise TypeError(f"Variant names must be strings, not {name!r}")
2783+
2784+
def _as_format(self):
2785+
def str_val(name):
2786+
name = name.encode()
2787+
return Const(int.from_bytes(name, "little"), len(name) * 8)
2788+
value = SwitchValue(self._value, [
2789+
(val, str_val(name))
2790+
for val, name in self._variants.items()
2791+
] + [(None, str_val("[unknown]"))])
2792+
return Format("{:s}", value)
2793+
2794+
def __repr__(self):
2795+
variants = "".join(
2796+
f" ({val!r} {name!r})"
2797+
for val, name in self._variants.items()
2798+
)
2799+
return f"(format-enum {self._value!r}{variants})"
2800+
2801+
2802+
class Struct(_FormatLike):
2803+
def __init__(self, value, /, fields):
2804+
self._value = Value.cast(value)
2805+
self._fields: dict[str, _FormatLike] = dict(fields)
2806+
for name, format in self._fields.items():
2807+
if not isinstance(name, str):
2808+
raise TypeError(f"Field names must be strings, not {name!r}")
2809+
if not isinstance(format, _FormatLike):
2810+
raise TypeError(f"Field format must be a 'Format', not {format!r}")
2811+
2812+
def _as_format(self):
2813+
chunks = ["{"]
2814+
for idx, (name, field) in enumerate(self._fields.items()):
2815+
if idx != 0:
2816+
chunks.append(", ")
2817+
chunks.append(f"{name}=")
2818+
chunks += field._as_format()._chunks
2819+
chunks.append("}")
2820+
return Format._from_chunks(chunks)
2821+
2822+
def __repr__(self):
2823+
fields = "".join(
2824+
f" ({name!r} {field!r})"
2825+
for name, field in self._fields.items()
2826+
)
2827+
return f"(format-struct {self._value!r}{fields})"
2828+
2829+
2830+
class Array(_FormatLike):
2831+
def __init__(self, value, /, fields):
2832+
self._value = Value.cast(value)
2833+
self._fields = list(fields)
2834+
for format in self._fields:
2835+
if not isinstance(format, (Format, Format.Enum, Format.Struct, Format.Array)):
2836+
raise TypeError(f"Field format must be a 'Format', not {format!r}")
2837+
2838+
def _as_format(self):
2839+
chunks = ["["]
2840+
for idx, field in enumerate(self._fields):
2841+
if idx != 0:
2842+
chunks.append(", ")
2843+
chunks += field._as_format()._chunks
2844+
chunks.append("]")
2845+
return Format._from_chunks(chunks)
2846+
2847+
def __repr__(self):
2848+
fields = "".join(
2849+
f" {field!r}"
2850+
for field in self._fields
2851+
)
2852+
return f"(format-array {self._value!r}{fields})"
2853+
2854+
27632855
class _StatementList(list):
27642856
def __repr__(self):
27652857
return "({})".format(" ".join(map(repr, self)))
@@ -2872,8 +2964,10 @@ def __init__(self, kind, test, message=None, *, src_loc_at=0):
28722964
self._test = Value.cast(test)
28732965
if isinstance(message, str):
28742966
message = Format._from_chunks([message])
2875-
if message is not None and not isinstance(message, Format):
2876-
raise TypeError(f"Property message must be None, str, or Format, not {message!r}")
2967+
if message is not None:
2968+
if not isinstance(message, _FormatLike):
2969+
raise TypeError(f"Property message must be None, str, or Format, not {message!r}")
2970+
message = message._as_format()
28772971
self._message = message
28782972
del self._MustUse__silence
28792973

tests/test_hdl_ast.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1736,6 +1736,97 @@ def test_format_wrong(self):
17361736
f"{fmt}"
17371737

17381738

1739+
class FormatEnumTestCase(FHDLTestCase):
1740+
def test_construct(self):
1741+
a = Signal(3)
1742+
fmt = Format.Enum(a, {1: "A", 2: "B", 3: "C"})
1743+
self.assertRepr(fmt, "(format-enum (sig a) (1 'A') (2 'B') (3 'C'))")
1744+
self.assertRepr(Format("{}", fmt), """
1745+
(format '{:s}' (switch-value (sig a)
1746+
(case 001 (const 8'd65))
1747+
(case 010 (const 8'd66))
1748+
(case 011 (const 8'd67))
1749+
(default (const 72'd1723507152241428428123))
1750+
))
1751+
""")
1752+
1753+
class MyEnum(Enum):
1754+
A = 0
1755+
B = 3
1756+
C = 4
1757+
1758+
fmt = Format.Enum(a, MyEnum)
1759+
self.assertRepr(fmt, "(format-enum (sig a) (0 'A') (3 'B') (4 'C'))")
1760+
self.assertRepr(Format("{}", fmt), """
1761+
(format '{:s}' (switch-value (sig a)
1762+
(case 000 (const 8'd65))
1763+
(case 011 (const 8'd66))
1764+
(case 100 (const 8'd67))
1765+
(default (const 72'd1723507152241428428123))
1766+
))
1767+
""")
1768+
1769+
def test_construct_wrong(self):
1770+
a = Signal(3)
1771+
with self.assertRaisesRegex(TypeError,
1772+
r"^Variant values must be integers, not 'a'$"):
1773+
Format.Enum(a, {"a": "B"})
1774+
with self.assertRaisesRegex(TypeError,
1775+
r"^Variant names must be strings, not 123$"):
1776+
Format.Enum(a, {1: 123})
1777+
1778+
1779+
class FormatStructTestCase(FHDLTestCase):
1780+
def test_construct(self):
1781+
sig = Signal(3)
1782+
fmt = Format.Struct(sig, {"a": Format("{}", sig[0]), "b": Format("{}", sig[1:3])})
1783+
self.assertRepr(fmt, """
1784+
(format-struct (sig sig)
1785+
('a' (format '{}' (slice (sig sig) 0:1)))
1786+
('b' (format '{}' (slice (sig sig) 1:3)))
1787+
)
1788+
""")
1789+
self.assertRepr(Format("{}", fmt), """
1790+
(format '{{a={}, b={}}}'
1791+
(slice (sig sig) 0:1)
1792+
(slice (sig sig) 1:3)
1793+
)
1794+
""")
1795+
1796+
def test_construct_wrong(self):
1797+
sig = Signal(3)
1798+
with self.assertRaisesRegex(TypeError,
1799+
r"^Field names must be strings, not 1$"):
1800+
Format.Struct(sig, {1: Format("{}", sig[1:3])})
1801+
with self.assertRaisesRegex(TypeError,
1802+
r"^Field format must be a 'Format', not \(slice \(sig sig\) 1:3\)$"):
1803+
Format.Struct(sig, {"a": sig[1:3]})
1804+
1805+
1806+
class FormatArrayTestCase(FHDLTestCase):
1807+
def test_construct(self):
1808+
sig = Signal(4)
1809+
fmt = Format.Array(sig, [Format("{}", sig[0:2]), Format("{}", sig[2:4])])
1810+
self.assertRepr(fmt, """
1811+
(format-array (sig sig)
1812+
(format '{}' (slice (sig sig) 0:2))
1813+
(format '{}' (slice (sig sig) 2:4))
1814+
)
1815+
""")
1816+
self.assertRepr(Format("{}", fmt), """
1817+
(format '[{}, {}]'
1818+
(slice (sig sig) 0:2)
1819+
(slice (sig sig) 2:4)
1820+
)
1821+
""")
1822+
1823+
def test_construct_wrong(self):
1824+
sig = Signal(3)
1825+
with self.assertRaisesRegex(TypeError,
1826+
r"^Field format must be a 'Format', not \(slice \(sig sig\) 1:3\)$"):
1827+
Format.Array(sig, [sig[1:3]])
1828+
1829+
17391830
class PrintTestCase(FHDLTestCase):
17401831
def test_construct(self):
17411832
a = Signal()

0 commit comments

Comments
 (0)