Skip to content

Commit c424d03

Browse files
committed
hdl._nir: add combinational cycle detection.
Fixes #704. Fixes #1143.
1 parent 8bf4f77 commit c424d03

File tree

4 files changed

+172
-7
lines changed

4 files changed

+172
-7
lines changed

amaranth/hdl/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ._ir import UnusedElaboratable, Elaboratable, DriverConflict, Fragment
1010
from ._ir import Instance, IOBufferInstance
1111
from ._mem import FrozenMemory, MemoryData, MemoryInstance, Memory, ReadPort, WritePort, DummyPort
12+
from ._nir import CombinationalCycle
1213
from ._rec import Record
1314
from ._xfrm import DomainRenamer, ResetInserter, EnableInserter
1415

@@ -28,6 +29,8 @@
2829
# _ir
2930
"UnusedElaboratable", "Elaboratable", "DriverConflict", "Fragment",
3031
"Instance", "IOBufferInstance",
32+
# _nir
33+
"CombinationalCycle",
3134
# _mem
3235
"FrozenMemory", "MemoryData", "MemoryInstance", "Memory", "ReadPort", "WritePort", "DummyPort",
3336
# _rec

amaranth/hdl/_ir.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ def __init__(self, netlist: _nir.Netlist, design: Design, *, all_undef_to_ff=Fal
709709
def emit_signal(self, signal) -> _nir.Value:
710710
if signal in self.netlist.signals:
711711
return self.netlist.signals[signal]
712-
value = self.netlist.alloc_late_value(len(signal))
712+
value = self.netlist.alloc_late_value(signal)
713713
self.netlist.signals[signal] = value
714714
for bit, net in enumerate(value):
715715
self.late_net_to_signal[net] = (signal, bit)
@@ -1738,6 +1738,7 @@ def build_netlist(fragment, ports=(), *, name="top", all_undef_to_ff=False, **kw
17381738
design = fragment.prepare(ports=ports, hierarchy=(name,), **kwargs)
17391739
netlist = _nir.Netlist()
17401740
_emit_netlist(netlist, design, all_undef_to_ff=all_undef_to_ff)
1741+
netlist.check_comb_cycles()
17411742
netlist.resolve_all_nets()
17421743
_compute_net_flows(netlist)
17431744
_compute_ports(netlist)

amaranth/hdl/_nir.py

Lines changed: 147 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Iterable
1+
from typing import Iterable, Any
22
import enum
33

44
from ._ast import SignalDict
@@ -7,8 +7,9 @@
77

88
__all__ = [
99
# Netlist core
10+
"CombinationalCycle",
1011
"Net", "Value", "IONet", "IOValue",
11-
"FormatValue", "Format",
12+
"FormatValue", "Format", "SignalField",
1213
"Netlist", "ModuleNetFlow", "IODirection", "Module", "Cell", "Top",
1314
# Computation cells
1415
"Operator", "Part",
@@ -25,6 +26,10 @@
2526
]
2627

2728

29+
class CombinationalCycle(Exception):
30+
pass
31+
32+
2833
class Net(int):
2934
__slots__ = ()
3035

@@ -335,6 +340,7 @@ class Netlist:
335340
modules : list of ``Module``
336341
cells : list of ``Cell``
337342
connections : dict of (negative) int to int
343+
late_to_signal : dict of (late) Net to its Signal and bit number
338344
io_ports : list of ``IOPort``
339345
signals : dict of Signal to ``Value``
340346
signal_fields: dict of Signal to dict of tuple[str | int] to SignalField
@@ -344,6 +350,7 @@ def __init__(self):
344350
self.modules: list[Module] = []
345351
self.cells: list[Cell] = [Top()]
346352
self.connections: dict[Net, Net] = {}
353+
self.late_to_signal: dict[Net, (_ast.Signal, int)] = {}
347354
self.io_ports: list[_ast.IOPort] = []
348355
self.signals = SignalDict()
349356
self.signal_fields = SignalDict()
@@ -405,16 +412,75 @@ def add_value_cell(self, width: int, cell):
405412
cell_idx = self.add_cell(cell)
406413
return Value(Net.from_cell(cell_idx, bit) for bit in range(width))
407414

408-
def alloc_late_value(self, width: int):
409-
self.last_late_net -= width
410-
return Value(Net.from_late(self.last_late_net + bit) for bit in range(width))
415+
def alloc_late_value(self, signal: _ast.Signal):
416+
self.last_late_net -= len(signal)
417+
value = Value(Net.from_late(self.last_late_net + bit) for bit in range(len(signal)))
418+
for bit, net in enumerate(value):
419+
self.late_to_signal[net] = signal, bit
420+
return value
411421

412422
@property
413423
def top(self):
414424
top = self.cells[0]
415425
assert isinstance(top, Top)
416426
return top
417427

428+
def check_comb_cycles(self):
429+
class Cycle:
430+
def __init__(self, start):
431+
self.start = start
432+
self.path = []
433+
434+
checked = set()
435+
busy = set()
436+
437+
def traverse(net):
438+
if net in checked:
439+
return None
440+
441+
if net in busy:
442+
return Cycle(net)
443+
busy.add(net)
444+
445+
cycle = None
446+
if net.is_const:
447+
pass
448+
elif net.is_late:
449+
cycle = traverse(self.connections[net])
450+
if cycle is not None:
451+
sig, bit = self.late_to_signal[net]
452+
cycle.path.append((sig, bit, sig.src_loc))
453+
else:
454+
for src, src_loc in self.cells[net.cell].comb_edges_to(net.bit):
455+
cycle = traverse(src)
456+
if cycle is not None:
457+
cycle.path.append((self.cells[net.cell], net.bit, src_loc))
458+
break
459+
460+
if cycle is not None and cycle.start == net:
461+
msg = ["Combinational cycle detected, path:\n"]
462+
for obj, bit, src_loc in reversed(cycle.path):
463+
if isinstance(obj, _ast.Signal):
464+
obj = f"signal {obj.name}"
465+
elif isinstance(obj, Operator):
466+
obj = f"operator {obj.operator}"
467+
else:
468+
obj = f"cell {obj.__class__.__name__}"
469+
src_loc = "[unknown]" if src_loc is None else f"{src_loc[0]}:{src_loc[1]}"
470+
msg.append(f" {src_loc}: {obj} bit {bit}\n")
471+
raise CombinationalCycle("".join(msg)) from None
472+
473+
busy.remove(net)
474+
checked.add(net)
475+
return cycle
476+
477+
for cell_idx, cell in enumerate(self.cells):
478+
for net in cell.output_nets(cell_idx):
479+
assert traverse(net) is None
480+
for value in self.signals.values():
481+
for net in value:
482+
assert traverse(net) is None
483+
418484

419485
class ModuleNetFlow(enum.Enum):
420486
"""Describes how a given Net flows into or out of a Module.
@@ -509,6 +575,9 @@ def io_nets(self):
509575
def resolve_nets(self, netlist: Netlist):
510576
raise NotImplementedError
511577

578+
def comb_edges_to(self, bit: int) -> "Iterable[(Net, Any)]":
579+
raise NotImplementedError
580+
512581

513582
class Top(Cell):
514583
"""A special cell type representing top-level non-IO ports. Must be present in the netlist exactly
@@ -558,6 +627,9 @@ def __repr__(self):
558627
ports = "".join(ports)
559628
return f"(top{ports})"
560629

630+
def comb_edges_to(self, bit):
631+
return []
632+
561633

562634
class Operator(Cell):
563635
"""Roughly corresponds to ``hdl.ast.Operator``.
@@ -627,6 +699,28 @@ def __repr__(self):
627699
inputs = " ".join(repr(input) for input in self.inputs)
628700
return f"({self.operator} {inputs})"
629701

702+
def comb_edges_to(self, bit):
703+
if len(self.inputs) == 1:
704+
if self.operator == "~":
705+
yield (self.inputs[0][bit], self.src_loc)
706+
else:
707+
for net in self.inputs[0]:
708+
yield (net, self.src_loc)
709+
elif len(self.inputs) == 2:
710+
if self.operator in ("&", "|", "^"):
711+
yield (self.inputs[0][bit], self.src_loc)
712+
yield (self.inputs[1][bit], self.src_loc)
713+
else:
714+
for net in self.inputs[0]:
715+
yield (net, self.src_loc)
716+
for net in self.inputs[1]:
717+
yield (net, self.src_loc)
718+
else:
719+
assert self.operator == "m"
720+
yield (self.inputs[0][0], self.src_loc)
721+
yield (self.inputs[1][bit], self.src_loc)
722+
yield (self.inputs[2][bit], self.src_loc)
723+
630724

631725
class Part(Cell):
632726
"""Corresponds to ``hdl.ast.Part``.
@@ -666,6 +760,12 @@ def __repr__(self):
666760
value_signed = "signed" if self.value_signed else "unsigned"
667761
return f"(part {self.value} {value_signed} {self.offset} {self.width} {self.stride})"
668762

763+
def comb_edges_to(self, bit):
764+
for net in self.value:
765+
yield (net, self.src_loc)
766+
for net in self.offset:
767+
yield (net, self.src_loc)
768+
669769

670770
class Matches(Cell):
671771
"""A combinatorial cell performing a comparison like ``Value.matches``
@@ -698,6 +798,10 @@ def __repr__(self):
698798
patterns = " ".join(self.patterns)
699799
return f"(matches {self.value} {patterns})"
700800

801+
def comb_edges_to(self, bit):
802+
for net in self.value:
803+
yield (net, self.src_loc)
804+
701805

702806
class PriorityMatch(Cell):
703807
"""Used to represent a single switch on the control plane of processes.
@@ -733,6 +837,11 @@ def resolve_nets(self, netlist: Netlist):
733837
def __repr__(self):
734838
return f"(priority_match {self.en} {self.inputs})"
735839

840+
def comb_edges_to(self, bit):
841+
yield (self.en, self.src_loc)
842+
for net in self.inputs[:bit + 1]:
843+
yield (net, self.src_loc)
844+
736845

737846
class Assignment:
738847
"""A single assignment in an ``AssignmentList``.
@@ -809,6 +918,13 @@ def __repr__(self):
809918
assignments = " ".join(repr(assign) for assign in self.assignments)
810919
return f"(assignment_list {self.default} {assignments})"
811920

921+
def comb_edges_to(self, bit):
922+
yield (self.default[bit], self.src_loc)
923+
for assign in self.assignments:
924+
yield (assign.cond, assign.src_loc)
925+
if bit >= assign.start and bit < assign.start + len(assign.value):
926+
yield (assign.value[bit - assign.start], assign.src_loc)
927+
812928

813929
class FlipFlop(Cell):
814930
"""A flip-flop. ``data`` is the data input. ``init`` is the initial and async reset value.
@@ -853,6 +969,10 @@ def __repr__(self):
853969
attributes = "".join(f" (attr {key} {val!r})" for key, val in self.attributes.items())
854970
return f"(flipflop {self.data} {self.init} {self.clk_edge} {self.clk} {self.arst}{attributes})"
855971

972+
def comb_edges_to(self, bit):
973+
yield (self.clk, self.src_loc)
974+
yield (self.arst, self.src_loc)
975+
856976

857977
class Memory(Cell):
858978
"""Corresponds to ``Memory``. ``init`` must have length equal to ``depth``.
@@ -960,6 +1080,10 @@ def resolve_nets(self, netlist: Netlist):
9601080
def __repr__(self):
9611081
return f"(read_port {self.memory} {self.width} {self.addr})"
9621082

1083+
def comb_edges_to(self, bit):
1084+
for net in self.addr:
1085+
yield (net, self.src_loc)
1086+
9631087

9641088
class SyncReadPort(Cell):
9651089
"""A single synchronous read port of a memory. The cell output is the data port.
@@ -1004,6 +1128,9 @@ def __repr__(self):
10041128
transparent_for = " ".join(str(port) for port in self.transparent_for)
10051129
return f"(read_port {self.memory} {self.width} {self.addr} {self.en} {self.clk_edge} {self.clk} ({transparent_for}))"
10061130

1131+
def comb_edges_to(self, bit):
1132+
return []
1133+
10071134

10081135
class AsyncPrint(Cell):
10091136
"""Corresponds to ``Print`` in the "comb" domain.
@@ -1087,6 +1214,9 @@ def resolve_nets(self, netlist: Netlist):
10871214
def __repr__(self):
10881215
return f"(initial)"
10891216

1217+
def comb_edges_to(self, bit):
1218+
return []
1219+
10901220

10911221
class AnyValue(Cell):
10921222
"""Corresponds to ``AnyConst`` or ``AnySeq``. ``kind`` must be either ``'anyconst'``
@@ -1117,6 +1247,9 @@ def resolve_nets(self, netlist: Netlist):
11171247
def __repr__(self):
11181248
return f"({self.kind} {self.width})"
11191249

1250+
def comb_edges_to(self, bit):
1251+
return []
1252+
11201253

11211254
class AsyncProperty(Cell):
11221255
"""Corresponds to ``Assert``, ``Assume``, or ``Cover`` in the "comb" domain.
@@ -1274,6 +1407,10 @@ def __repr__(self):
12741407
items = " ".join(items)
12751408
return f"(instance {self.type!r} {self.name!r} {items})"
12761409

1410+
def comb_edges_to(self, bit):
1411+
# don't ask me, I'm a housecat
1412+
return []
1413+
12771414

12781415
class IOBuffer(Cell):
12791416
"""An IO buffer cell. This cell does two things:
@@ -1328,3 +1465,8 @@ def __repr__(self):
13281465
return f"(iob {self.dir.value} {self.port})"
13291466
else:
13301467
return f"(iob {self.dir.value} {self.port} {self.o} {self.oe})"
1468+
1469+
def comb_edges_to(self, bit):
1470+
if self.dir is not IODirection.Input:
1471+
yield (self.o[bit], self.src_loc)
1472+
yield (self.oe, self.src_loc)

tests/test_hdl_ir.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from amaranth.hdl._dsl import *
88
from amaranth.hdl._ir import *
99
from amaranth.hdl._mem import *
10-
from amaranth.hdl._nir import SignalField
10+
from amaranth.hdl._nir import SignalField, CombinationalCycle
1111

1212
from amaranth.lib import enum, data
1313

@@ -3542,3 +3542,22 @@ class MyEnum(enum.Enum, shape=unsigned(2)):
35423542
self.assertEqual(nl.signal_fields[s4], {
35433543
(): SignalField(nl.signals[s4], signed=False),
35443544
})
3545+
3546+
3547+
class CycleTestCase(FHDLTestCase):
3548+
def test_cycle(self):
3549+
a = Signal()
3550+
b = Signal()
3551+
m = Module()
3552+
m.d.comb += [
3553+
a.eq(~b),
3554+
b.eq(~a),
3555+
]
3556+
with self.assertRaisesRegex(CombinationalCycle,
3557+
r"^Combinational cycle detected, path:\n"
3558+
r".*test_hdl_ir.py:\d+: operator ~ bit 0\n"
3559+
r".*test_hdl_ir.py:\d+: signal b bit 0\n"
3560+
r".*test_hdl_ir.py:\d+: operator ~ bit 0\n"
3561+
r".*test_hdl_ir.py:\d+: signal a bit 0\n"
3562+
r"$"):
3563+
build_netlist(Fragment.get(m, None), [])

0 commit comments

Comments
 (0)