Skip to content

hdl._nir: add combinational cycle detection. #1330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions amaranth/hdl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ._ir import UnusedElaboratable, Elaboratable, DriverConflict, Fragment
from ._ir import Instance, IOBufferInstance
from ._mem import FrozenMemory, MemoryData, MemoryInstance, Memory, ReadPort, WritePort, DummyPort
from ._nir import CombinationalCycle
from ._rec import Record
from ._xfrm import DomainRenamer, ResetInserter, EnableInserter

Expand All @@ -28,6 +29,8 @@
# _ir
"UnusedElaboratable", "Elaboratable", "DriverConflict", "Fragment",
"Instance", "IOBufferInstance",
# _nir
"CombinationalCycle",
# _mem
"FrozenMemory", "MemoryData", "MemoryInstance", "Memory", "ReadPort", "WritePort", "DummyPort",
# _rec
Expand Down
3 changes: 2 additions & 1 deletion amaranth/hdl/_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def __init__(self, netlist: _nir.Netlist, design: Design, *, all_undef_to_ff=Fal
def emit_signal(self, signal) -> _nir.Value:
if signal in self.netlist.signals:
return self.netlist.signals[signal]
value = self.netlist.alloc_late_value(len(signal))
value = self.netlist.alloc_late_value(signal)
self.netlist.signals[signal] = value
for bit, net in enumerate(value):
self.late_net_to_signal[net] = (signal, bit)
Expand Down Expand Up @@ -1738,6 +1738,7 @@ def build_netlist(fragment, ports=(), *, name="top", all_undef_to_ff=False, **kw
design = fragment.prepare(ports=ports, hierarchy=(name,), **kwargs)
netlist = _nir.Netlist()
_emit_netlist(netlist, design, all_undef_to_ff=all_undef_to_ff)
netlist.check_comb_cycles()
netlist.resolve_all_nets()
_compute_net_flows(netlist)
_compute_ports(netlist)
Expand Down
152 changes: 147 additions & 5 deletions amaranth/hdl/_nir.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable
from typing import Iterable, Any
import enum

from ._ast import SignalDict
Expand All @@ -7,8 +7,9 @@

__all__ = [
# Netlist core
"CombinationalCycle",
"Net", "Value", "IONet", "IOValue",
"FormatValue", "Format",
"FormatValue", "Format", "SignalField",
"Netlist", "ModuleNetFlow", "IODirection", "Module", "Cell", "Top",
# Computation cells
"Operator", "Part",
Expand All @@ -25,6 +26,10 @@
]


class CombinationalCycle(Exception):
pass


class Net(int):
__slots__ = ()

Expand Down Expand Up @@ -335,6 +340,7 @@ class Netlist:
modules : list of ``Module``
cells : list of ``Cell``
connections : dict of (negative) int to int
late_to_signal : dict of (late) Net to its Signal and bit number
io_ports : list of ``IOPort``
signals : dict of Signal to ``Value``
signal_fields: dict of Signal to dict of tuple[str | int] to SignalField
Expand All @@ -344,6 +350,7 @@ def __init__(self):
self.modules: list[Module] = []
self.cells: list[Cell] = [Top()]
self.connections: dict[Net, Net] = {}
self.late_to_signal: dict[Net, (_ast.Signal, int)] = {}
self.io_ports: list[_ast.IOPort] = []
self.signals = SignalDict()
self.signal_fields = SignalDict()
Expand Down Expand Up @@ -405,16 +412,75 @@ def add_value_cell(self, width: int, cell):
cell_idx = self.add_cell(cell)
return Value(Net.from_cell(cell_idx, bit) for bit in range(width))

def alloc_late_value(self, width: int):
self.last_late_net -= width
return Value(Net.from_late(self.last_late_net + bit) for bit in range(width))
def alloc_late_value(self, signal: _ast.Signal):
self.last_late_net -= len(signal)
value = Value(Net.from_late(self.last_late_net + bit) for bit in range(len(signal)))
for bit, net in enumerate(value):
self.late_to_signal[net] = signal, bit
return value

@property
def top(self):
top = self.cells[0]
assert isinstance(top, Top)
return top

def check_comb_cycles(self):
class Cycle:
def __init__(self, start):
self.start = start
self.path = []

checked = set()
busy = set()

def traverse(net):
if net in checked:
return None

if net in busy:
return Cycle(net)
busy.add(net)

cycle = None
if net.is_const:
pass
elif net.is_late:
cycle = traverse(self.connections[net])
if cycle is not None:
sig, bit = self.late_to_signal[net]
cycle.path.append((sig, bit, sig.src_loc))
else:
for src, src_loc in self.cells[net.cell].comb_edges_to(net.bit):
cycle = traverse(src)
if cycle is not None:
cycle.path.append((self.cells[net.cell], net.bit, src_loc))
break

if cycle is not None and cycle.start == net:
msg = ["Combinational cycle detected, path:\n"]
for obj, bit, src_loc in reversed(cycle.path):
if isinstance(obj, _ast.Signal):
obj = f"signal {obj.name}"
elif isinstance(obj, Operator):
obj = f"operator {obj.operator}"
else:
obj = f"cell {obj.__class__.__name__}"
src_loc = "<unknown>:0" if src_loc is None else f"{src_loc[0]}:{src_loc[1]}"
msg.append(f" {src_loc}: {obj} bit {bit}\n")
raise CombinationalCycle("".join(msg))

busy.remove(net)
checked.add(net)
return cycle

for cell_idx, cell in enumerate(self.cells):
for net in cell.output_nets(cell_idx):
assert traverse(net) is None
for value in self.signals.values():
for net in value:
assert traverse(net) is None


class ModuleNetFlow(enum.Enum):
"""Describes how a given Net flows into or out of a Module.
Expand Down Expand Up @@ -509,6 +575,9 @@ def io_nets(self):
def resolve_nets(self, netlist: Netlist):
raise NotImplementedError

def comb_edges_to(self, bit: int) -> "Iterable[(Net, Any)]":
raise NotImplementedError


class Top(Cell):
"""A special cell type representing top-level non-IO ports. Must be present in the netlist exactly
Expand Down Expand Up @@ -558,6 +627,9 @@ def __repr__(self):
ports = "".join(ports)
return f"(top{ports})"

def comb_edges_to(self, bit):
return []


class Operator(Cell):
"""Roughly corresponds to ``hdl.ast.Operator``.
Expand Down Expand Up @@ -627,6 +699,28 @@ def __repr__(self):
inputs = " ".join(repr(input) for input in self.inputs)
return f"({self.operator} {inputs})"

def comb_edges_to(self, bit):
if len(self.inputs) == 1:
if self.operator == "~":
yield (self.inputs[0][bit], self.src_loc)
else:
for net in self.inputs[0]:
yield (net, self.src_loc)
elif len(self.inputs) == 2:
if self.operator in ("&", "|", "^"):
yield (self.inputs[0][bit], self.src_loc)
yield (self.inputs[1][bit], self.src_loc)
else:
for net in self.inputs[0]:
yield (net, self.src_loc)
for net in self.inputs[1]:
yield (net, self.src_loc)
else:
assert self.operator == "m"
yield (self.inputs[0][0], self.src_loc)
yield (self.inputs[1][bit], self.src_loc)
yield (self.inputs[2][bit], self.src_loc)


class Part(Cell):
"""Corresponds to ``hdl.ast.Part``.
Expand Down Expand Up @@ -666,6 +760,12 @@ def __repr__(self):
value_signed = "signed" if self.value_signed else "unsigned"
return f"(part {self.value} {value_signed} {self.offset} {self.width} {self.stride})"

def comb_edges_to(self, bit):
for net in self.value:
yield (net, self.src_loc)
for net in self.offset:
yield (net, self.src_loc)


class Matches(Cell):
"""A combinatorial cell performing a comparison like ``Value.matches``
Expand Down Expand Up @@ -698,6 +798,10 @@ def __repr__(self):
patterns = " ".join(self.patterns)
return f"(matches {self.value} {patterns})"

def comb_edges_to(self, bit):
for net in self.value:
yield (net, self.src_loc)


class PriorityMatch(Cell):
"""Used to represent a single switch on the control plane of processes.
Expand Down Expand Up @@ -733,6 +837,11 @@ def resolve_nets(self, netlist: Netlist):
def __repr__(self):
return f"(priority_match {self.en} {self.inputs})"

def comb_edges_to(self, bit):
yield (self.en, self.src_loc)
for net in self.inputs[:bit + 1]:
yield (net, self.src_loc)


class Assignment:
"""A single assignment in an ``AssignmentList``.
Expand Down Expand Up @@ -809,6 +918,13 @@ def __repr__(self):
assignments = " ".join(repr(assign) for assign in self.assignments)
return f"(assignment_list {self.default} {assignments})"

def comb_edges_to(self, bit):
yield (self.default[bit], self.src_loc)
for assign in self.assignments:
yield (assign.cond, assign.src_loc)
if bit >= assign.start and bit < assign.start + len(assign.value):
yield (assign.value[bit - assign.start], assign.src_loc)


class FlipFlop(Cell):
"""A flip-flop. ``data`` is the data input. ``init`` is the initial and async reset value.
Expand Down Expand Up @@ -853,6 +969,10 @@ def __repr__(self):
attributes = "".join(f" (attr {key} {val!r})" for key, val in self.attributes.items())
return f"(flipflop {self.data} {self.init} {self.clk_edge} {self.clk} {self.arst}{attributes})"

def comb_edges_to(self, bit):
yield (self.clk, self.src_loc)
yield (self.arst, self.src_loc)


class Memory(Cell):
"""Corresponds to ``Memory``. ``init`` must have length equal to ``depth``.
Expand Down Expand Up @@ -960,6 +1080,10 @@ def resolve_nets(self, netlist: Netlist):
def __repr__(self):
return f"(read_port {self.memory} {self.width} {self.addr})"

def comb_edges_to(self, bit):
for net in self.addr:
yield (net, self.src_loc)


class SyncReadPort(Cell):
"""A single synchronous read port of a memory. The cell output is the data port.
Expand Down Expand Up @@ -1004,6 +1128,9 @@ def __repr__(self):
transparent_for = " ".join(str(port) for port in self.transparent_for)
return f"(read_port {self.memory} {self.width} {self.addr} {self.en} {self.clk_edge} {self.clk} ({transparent_for}))"

def comb_edges_to(self, bit):
return []


class AsyncPrint(Cell):
"""Corresponds to ``Print`` in the "comb" domain.
Expand Down Expand Up @@ -1087,6 +1214,9 @@ def resolve_nets(self, netlist: Netlist):
def __repr__(self):
return f"(initial)"

def comb_edges_to(self, bit):
return []


class AnyValue(Cell):
"""Corresponds to ``AnyConst`` or ``AnySeq``. ``kind`` must be either ``'anyconst'``
Expand Down Expand Up @@ -1117,6 +1247,9 @@ def resolve_nets(self, netlist: Netlist):
def __repr__(self):
return f"({self.kind} {self.width})"

def comb_edges_to(self, bit):
return []


class AsyncProperty(Cell):
"""Corresponds to ``Assert``, ``Assume``, or ``Cover`` in the "comb" domain.
Expand Down Expand Up @@ -1274,6 +1407,10 @@ def __repr__(self):
items = " ".join(items)
return f"(instance {self.type!r} {self.name!r} {items})"

def comb_edges_to(self, bit):
# don't ask me, I'm a housecat
return []


class IOBuffer(Cell):
"""An IO buffer cell. This cell does two things:
Expand Down Expand Up @@ -1328,3 +1465,8 @@ def __repr__(self):
return f"(iob {self.dir.value} {self.port})"
else:
return f"(iob {self.dir.value} {self.port} {self.o} {self.oe})"

def comb_edges_to(self, bit):
if self.dir is not IODirection.Input:
yield (self.o[bit], self.src_loc)
yield (self.oe, self.src_loc)
21 changes: 20 additions & 1 deletion tests/test_hdl_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from amaranth.hdl._dsl import *
from amaranth.hdl._ir import *
from amaranth.hdl._mem import *
from amaranth.hdl._nir import SignalField
from amaranth.hdl._nir import SignalField, CombinationalCycle

from amaranth.lib import enum, data

Expand Down Expand Up @@ -3542,3 +3542,22 @@ class MyEnum(enum.Enum, shape=unsigned(2)):
self.assertEqual(nl.signal_fields[s4], {
(): SignalField(nl.signals[s4], signed=False),
})


class CycleTestCase(FHDLTestCase):
def test_cycle(self):
a = Signal()
b = Signal()
m = Module()
m.d.comb += [
a.eq(~b),
b.eq(~a),
]
with self.assertRaisesRegex(CombinationalCycle,
r"^Combinational cycle detected, path:\n"
r".*test_hdl_ir.py:\d+: operator ~ bit 0\n"
r".*test_hdl_ir.py:\d+: signal b bit 0\n"
r".*test_hdl_ir.py:\d+: operator ~ bit 0\n"
r".*test_hdl_ir.py:\d+: signal a bit 0\n"
r"$"):
build_netlist(Fragment.get(m, None), [])