Skip to content

hdl._nir, back.rtlil: use Format.* to emit enum attributes and wires for fields. #1323

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions amaranth/back/rtlil.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io

from ..utils import bits_for
from .._utils import to_binary
from ..lib import wiring
from ..hdl import _repr, _ast, _ir, _nir

Expand Down Expand Up @@ -421,6 +422,7 @@ def emit(self):
self.emit_cell_wires()
self.emit_submodule_wires()
self.emit_connects()
self.emit_signal_fields()
self.emit_submodules()
self.emit_cells()

Expand Down Expand Up @@ -491,12 +493,12 @@ def emit_signal_wires(self):
attrs.update(signal.attrs)
self.value_src_loc[value] = signal.src_loc

for repr in signal._value_repr:
if repr.path == () and isinstance(repr.format, _repr.FormatEnum):
enum = repr.format.enum
attrs["enum_base_type"] = enum.__name__
for enum_value in enum:
attrs["enum_value_{:0{}b}".format(enum_value.value, len(signal))] = enum_value.name
field = self.netlist.signal_fields[signal][()]
if field.enum_name is not None:
attrs["enum_base_type"] = field.enum_name
if field.enum_variants is not None:
for var_val, var_name in field.enum_variants.items():
attrs["enum_value_" + to_binary(var_val, len(signal))] = var_name

if name in self.module.ports:
port_value, _flow = self.module.ports[name]
Expand Down Expand Up @@ -666,6 +668,30 @@ def emit_connects(self):
if name not in self.driven_sigports:
self.builder.connect(wire.name, self.sigspec(value))

def emit_signal_fields(self):
for signal, name in self.module.signal_names.items():
fields = self.netlist.signal_fields[signal]
for path, field in fields.items():
if path == ():
continue
name_parts = [name]
for component in path:
if isinstance(component, str):
name_parts.append(f".{component}")
elif isinstance(component, int):
name_parts.append(f"[{component}]")
else:
assert False # :nocov:
attrs = {}
if field.enum_name is not None:
attrs["enum_base_type"] = field.enum_name
if field.enum_variants is not None:
for var_val, var_name in field.enum_variants.items():
attrs["enum_value_" + to_binary(var_val, len(field.value))] = var_name
wire = self.builder.wire(width=len(field.value), signed=field.signed, attrs=attrs,
name="".join(name_parts), src_loc=signal.src_loc)
self.builder.connect(wire.name, self.sigspec(field.value))

def emit_submodules(self):
for submodule_idx in self.module.submodules:
submodule = self.netlist.modules[submodule_idx]
Expand Down
5 changes: 5 additions & 0 deletions amaranth/hdl/_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2095,6 +2095,11 @@ def __init__(self, shape=None, *, name=None, init=None, reset=None, reset_less=F

self._attrs = OrderedDict(() if attrs is None else attrs)

if isinstance(orig_shape, ShapeCastable):
self._format = orig_shape.format(orig_shape(self), "")
else:
self._format = Format("{}", self)

if decoder is not None:
# The value representation is specified explicitly. Since we do not expose `hdl._repr`,
# this is the only way to add a custom filter to the signal right now.
Expand Down
41 changes: 39 additions & 2 deletions amaranth/hdl/_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ def emit_value(self, builder):


class NetlistEmitter:
def __init__(self, netlist: _nir.Netlist, design, *, all_undef_to_ff=False):
def __init__(self, netlist: _nir.Netlist, design: Design, *, all_undef_to_ff=False):
self.netlist = netlist
self.design = design
self.all_undef_to_ff = all_undef_to_ff
Expand Down Expand Up @@ -776,7 +776,7 @@ def extend(self, value: _nir.Value, signed: bool, width: int):
def emit_operator(self, module_idx: int, operator: str, *inputs: _nir.Value, src_loc):
op = _nir.Operator(module_idx, operator=operator, inputs=inputs, src_loc=src_loc)
return self.netlist.add_value_cell(op.width, op)

def emit_matches(self, module_idx: int, value: _nir.Value, patterns, *, src_loc):
key = module_idx, value, patterns, src_loc
try:
Expand Down Expand Up @@ -1334,6 +1334,42 @@ def emit_top_ports(self, fragment: _ir.Fragment):
else:
raise ValueError(f"Invalid port direction {dir!r}")

def emit_signal_fields(self):
for signal, fragment in self.design.signal_lca.items():
module_idx = self.fragment_module_idx[fragment]
fields = {}
def emit_format(path, fmt):
if isinstance(fmt, _ast.Format):
specs = [
chunk[0]
for chunk in fmt._chunks
if not isinstance(chunk, str)
]
if len(specs) != 1:
return
val, signed = self.emit_rhs(module_idx, specs[0])
fields[path] = _nir.SignalField(val, signed=signed)
elif isinstance(fmt, _ast.Format.Enum):
val, signed = self.emit_rhs(module_idx, fmt._value)
fields[path] = _nir.SignalField(val, signed=signed,
enum_name=fmt._name,
enum_variants=fmt._variants)
elif isinstance(fmt, _ast.Format.Struct):
val, signed = self.emit_rhs(module_idx, fmt._value)
fields[path] = _nir.SignalField(val, signed=signed)
for name, subfmt in fmt._fields.items():
emit_format(path + (name,), subfmt)
elif isinstance(fmt, _ast.Format.Array):
val, signed = self.emit_rhs(module_idx, fmt._value)
fields[path] = _nir.SignalField(val, signed=signed)
for idx, subfmt in enumerate(fmt._fields):
emit_format(path + (idx,), subfmt)
emit_format((), signal._format)
val, signed = self.emit_rhs(module_idx, signal)
if () not in fields or fields[()].value != val:
fields[()] = _nir.SignalField(val, signed=signed)
self.netlist.signal_fields[signal] = fields

def emit_drivers(self):
for driver in self.drivers.values():
if (driver.domain is not None and
Expand Down Expand Up @@ -1452,6 +1488,7 @@ def emit_fragment(self, fragment: _ir.Fragment, parent_module_idx: 'int | None',
for subfragment, _name, sub_src_loc in fragment.subfragments:
self.emit_fragment(subfragment, module_idx, cell_src_loc=sub_src_loc)
if parent_module_idx is None:
self.emit_signal_fields()
self.emit_drivers()
self.emit_top_ports(fragment)
if self.all_undef_to_ff:
Expand Down
21 changes: 21 additions & 0 deletions amaranth/hdl/_nir.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,22 @@ def resolve_nets(self, netlist: "Netlist"):
chunk.value = netlist.resolve_value(chunk.value)


class SignalField:
"""Describes a single field of a signal."""
def __init__(self, value, *, signed, enum_name=None, enum_variants=None):
self.value = Value(value)
self.signed = bool(signed)
self.enum_name = enum_name
self.enum_variants = enum_variants

def __eq__(self, other):
return (type(self) is type(other) and
self.value == other.value and
self.signed == other.signed and
self.enum_name == other.enum_name and
self.enum_variants == other.enum_variants)


class Netlist:
"""A fine netlist. Consists of:

Expand Down Expand Up @@ -321,6 +337,7 @@ class Netlist:
connections : dict of (negative) int to int
io_ports : list of ``IOPort``
signals : dict of Signal to ``Value``
signal_fields: dict of Signal to dict of tuple[str | int] to SignalField
last_late_net: int
"""
def __init__(self):
Expand All @@ -329,6 +346,7 @@ def __init__(self):
self.connections: dict[Net, Net] = {}
self.io_ports: list[_ast.IOPort] = []
self.signals = SignalDict()
self.signal_fields = SignalDict()
self.last_late_net = 0

def resolve_net(self, net: Net):
Expand All @@ -345,6 +363,9 @@ def resolve_all_nets(self):
cell.resolve_nets(self)
for sig in self.signals:
self.signals[sig] = self.resolve_value(self.signals[sig])
for fields in self.signal_fields.values():
for field in fields.values():
field.value = self.resolve_value(field.value)

def __repr__(self):
result = ["("]
Expand Down
63 changes: 62 additions & 1 deletion tests/test_back_rtlil.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from amaranth.back import rtlil
from amaranth.hdl import *
from amaranth.hdl._ast import *
from amaranth.lib import memory, wiring
from amaranth.lib import memory, wiring, data, enum

from .utils import *

Expand Down Expand Up @@ -2010,6 +2010,67 @@ def test_print_align(self):
""")


class DetailTestCase(RTLILTestCase):
def test_enum(self):
class MyEnum(enum.Enum, shape=unsigned(2)):
A = 0
B = 1
C = 2

sig = Signal(MyEnum)
m = Module()
m.d.comb += sig.eq(MyEnum.A)
self.assertRTLIL(m, [sig.as_value()], R"""
attribute \generator "Amaranth"
attribute \top 1
module \top
attribute \enum_base_type "MyEnum"
attribute \enum_value_00 "A"
attribute \enum_value_01 "B"
attribute \enum_value_10 "C"
wire width 2 output 0 \sig
connect \sig 2'00
end
""")

def test_struct(self):
class MyEnum(enum.Enum, shape=unsigned(2)):
A = 0
B = 1
C = 2

class Meow(data.Struct):
a: MyEnum
b: 3
c: signed(4)
d: data.ArrayLayout(2, 2)

sig = Signal(Meow)
m = Module()
self.assertRTLIL(m, [sig.as_value()], R"""
attribute \generator "Amaranth"
attribute \top 1
module \top
wire width 13 input 0 \sig
attribute \enum_base_type "MyEnum"
attribute \enum_value_00 "A"
attribute \enum_value_01 "B"
attribute \enum_value_10 "C"
wire width 2 \sig.a
wire width 3 \sig.b
wire width 4 signed \sig.c
wire width 4 \sig.d
wire width 2 \sig.d[0]
wire width 2 \sig.d[1]
connect \sig.a \sig [1:0]
connect \sig.b \sig [4:2]
connect \sig.c \sig [8:5]
connect \sig.d \sig [12:9]
connect \sig.d[0] \sig [10:9]
connect \sig.d[1] \sig [12:11]
end
""")

class ComponentTestCase(RTLILTestCase):
def test_component(self):
class MyComponent(wiring.Component):
Expand Down
41 changes: 41 additions & 0 deletions tests/test_hdl_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from amaranth.hdl._dsl import *
from amaranth.hdl._ir import *
from amaranth.hdl._mem import *
from amaranth.hdl._nir import SignalField

from amaranth.lib import enum, data

from .utils import *

Expand Down Expand Up @@ -3501,3 +3504,41 @@ def test_undef_to_ff_partial(self):
(cell 3 0 (flipflop 3.0:5 10 pos 0 0))
)
""")


class FieldsTestCase(FHDLTestCase):
def test_fields(self):
class MyEnum(enum.Enum, shape=unsigned(2)):
A = 0
B = 1
C = 2
l = data.StructLayout({"a": MyEnum, "b": signed(3)})
s1 = Signal(l)
s2 = Signal(MyEnum)
s3 = Signal(signed(3))
s4 = Signal(unsigned(4))
nl = build_netlist(Fragment.get(Module(), None), [
s1.as_value(), s2.as_value(), s3, s4,
])
self.assertEqual(nl.signal_fields[s1.as_value()], {
(): SignalField(nl.signals[s1.as_value()], signed=False),
('a',): SignalField(nl.signals[s1.as_value()][0:2], signed=False, enum_name="MyEnum", enum_variants={
0: "A",
1: "B",
2: "C",
}),
('b',): SignalField(nl.signals[s1.as_value()][2:5], signed=True)
})
self.assertEqual(nl.signal_fields[s2.as_value()], {
(): SignalField(nl.signals[s2.as_value()], signed=False, enum_name="MyEnum", enum_variants={
0: "A",
1: "B",
2: "C",
}),
})
self.assertEqual(nl.signal_fields[s3], {
(): SignalField(nl.signals[s3], signed=True),
})
self.assertEqual(nl.signal_fields[s4], {
(): SignalField(nl.signals[s4], signed=False),
})
3 changes: 3 additions & 0 deletions tests/test_lib_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,9 @@ def const(self, init):
def from_bits(self, bits):
return bits

def format(self, value, spec):
return Format("")

v = Signal(data.StructLayout({
"f": WrongCastable()
}))
Expand Down