Skip to content

Commit 51e0262

Browse files
tilkwhitequark
authored andcommitted
sim: group signal traces according to their function.
1 parent 89eae72 commit 51e0262

File tree

3 files changed

+109
-28
lines changed

3 files changed

+109
-28
lines changed

amaranth/sim/core.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -212,17 +212,25 @@ def write_vcd(self, vcd_file, gtkw_file=None, *, traces=(), fs_per_delta=0):
212212
file.close()
213213
raise ValueError("Cannot start writing waveforms after advancing simulation time")
214214

215-
for trace in traces:
216-
if isinstance(trace, ValueLike):
217-
trace_cast = Value.cast(trace)
215+
def traverse_traces(traces):
216+
if isinstance(traces, ValueLike):
217+
trace_cast = Value.cast(traces)
218218
if isinstance(trace_cast, MemoryData._Row):
219-
continue
219+
return
220220
for trace_signal in trace_cast._rhs_signals():
221221
if trace_signal.name == "":
222-
if trace_signal is trace:
222+
if trace_signal is traces:
223223
raise TypeError("Cannot trace signal with private name")
224224
else:
225-
raise TypeError(f"Cannot trace signal with private name (within {trace!r})")
225+
raise TypeError(f"Cannot trace signal with private name (within {traces!r})")
226+
elif isinstance(traces, (list, tuple)):
227+
for trace in traces:
228+
traverse_traces(trace)
229+
elif isinstance(traces, dict):
230+
for trace in traces.values():
231+
traverse_traces(trace)
232+
233+
traverse_traces(traces)
226234

227235
return self._engine.write_vcd(vcd_file=vcd_file, gtkw_file=gtkw_file,
228236
traces=traces, fs_per_delta=fs_per_delta)

amaranth/sim/pysim.py

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
import enum as py_enum
66

77
from ..hdl import *
8+
from ..hdl._mem import MemoryInstance
89
from ..hdl._ast import SignalDict
10+
from ..lib import data, wiring
911
from ._base import *
1012
from ._async import *
1113
from ._pyeval import eval_format, eval_value, eval_assign
@@ -49,7 +51,7 @@ def __init__(self, state, design, *, vcd_file, gtkw_file=None, traces=(), fs_per
4951
self.gtkw_file = gtkw_file
5052
self.gtkw_save = gtkw_file and vcd.gtkw.GTKWSave(self.gtkw_file)
5153

52-
self.traces = []
54+
self.traces = traces
5355

5456
signal_names = SignalDict()
5557
memories = {}
@@ -64,9 +66,9 @@ def __init__(self, state, design, *, vcd_file, gtkw_file=None, traces=(), fs_per
6466

6567
trace_names = SignalDict()
6668
assigned_names = set()
67-
for trace in traces:
68-
if isinstance(trace, ValueLike):
69-
trace = Value.cast(trace)
69+
def traverse_traces(traces):
70+
if isinstance(traces, ValueLike):
71+
trace = Value.cast(traces)
7072
if isinstance(trace, MemoryData._Row):
7173
memory = trace._memory
7274
if not memory in memories:
@@ -77,7 +79,6 @@ def __init__(self, state, design, *, vcd_file, gtkw_file=None, traces=(), fs_per
7779
assert name not in assigned_names
7880
memories[memory] = ("bench", name)
7981
assigned_names.add(name)
80-
self.traces.append(trace)
8182
else:
8283
for trace_signal in trace._rhs_signals():
8384
if trace_signal not in signal_names:
@@ -88,19 +89,27 @@ def __init__(self, state, design, *, vcd_file, gtkw_file=None, traces=(), fs_per
8889
assert name not in assigned_names
8990
trace_names[trace_signal] = {("bench", name)}
9091
assigned_names.add(name)
91-
self.traces.append(trace_signal)
92-
elif isinstance(trace, MemoryData):
93-
if not trace in memories:
94-
if trace.name not in assigned_names:
95-
name = trace.name
92+
elif isinstance(traces, MemoryData):
93+
if not traces in memories:
94+
if traces.name not in assigned_names:
95+
name = traces.name
9696
else:
97-
name = f"{trace.name}${len(assigned_names)}"
97+
name = f"{traces.name}${len(assigned_names)}"
9898
assert name not in assigned_names
99-
memories[trace] = ("bench", name)
99+
memories[traces] = ("bench", name)
100100
assigned_names.add(name)
101-
self.traces.append(trace)
101+
elif hasattr(traces, "signature") and isinstance(traces.signature, wiring.Signature):
102+
for name in traces.signature.members:
103+
traverse_traces(getattr(traces, name))
104+
elif isinstance(traces, list) or isinstance(traces, tuple):
105+
for trace in traces:
106+
traverse_traces(trace)
107+
elif isinstance(traces, dict):
108+
for trace in traces.values():
109+
traverse_traces(trace)
102110
else:
103-
raise TypeError(f"{trace!r} is not a traceable object")
111+
raise TypeError(f"{traces!r} is not a traceable object")
112+
traverse_traces(traces)
104113

105114
if self.vcd_writer is None:
106115
return
@@ -277,19 +286,40 @@ def close(self, timestamp):
277286
self.gtkw_save.dumpfile_size(self.vcd_file.tell())
278287

279288
self.gtkw_save.treeopen("top")
280-
for trace in self.traces:
281-
if isinstance(trace, Signal):
282-
for name in self.gtkw_signal_names[trace]:
289+
290+
def traverse_traces(traces):
291+
if isinstance(traces, Signal):
292+
for name in self.gtkw_signal_names[traces]:
283293
self.gtkw_save.trace(name)
284-
elif isinstance(trace, MemoryData):
285-
for row_names in self.gtkw_memory_names[trace]:
294+
elif isinstance(traces, data.View):
295+
with self.gtkw_save.group("view"):
296+
trace = Value.cast(traces)
297+
for trace_signal in trace._rhs_signals():
298+
for name in self.gtkw_signal_names[trace_signal]:
299+
self.gtkw_save.trace(name)
300+
elif isinstance(traces, ValueLike):
301+
traverse_traces(Value.cast(traces))
302+
elif isinstance(traces, MemoryData):
303+
for row_names in self.gtkw_memory_names[traces]:
286304
for name in row_names:
287305
self.gtkw_save.trace(name)
288-
elif isinstance(trace, MemoryData._Row):
289-
for name in self.gtkw_memory_names[trace._memory][trace._index]:
306+
elif isinstance(traces, MemoryData._Row):
307+
for name in self.gtkw_memory_names[traces._memory][traces._index]:
290308
self.gtkw_save.trace(name)
309+
elif hasattr(traces, "signature") and isinstance(traces.signature, wiring.Signature):
310+
with self.gtkw_save.group("interface"):
311+
for _, _, member in traces.signature.flatten(traces):
312+
traverse_traces(member)
313+
elif isinstance(traces, list) or isinstance(traces, tuple):
314+
for trace in traces:
315+
traverse_traces(trace)
316+
elif isinstance(traces, dict):
317+
for name, trace in traces.items():
318+
with self.gtkw_save.group(name):
319+
traverse_traces(trace)
291320
else:
292321
assert False # :nocov:
322+
traverse_traces(self.traces)
293323

294324
if self.close_vcd:
295325
self.vcd_file.close()

tests/test_sim.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from amaranth.sim import *
1717
from amaranth.sim._pyeval import eval_format
1818
from amaranth.lib.memory import Memory
19-
from amaranth.lib import enum, data
19+
from amaranth.lib import enum, data, wiring
2020

2121
from .utils import *
2222
from amaranth._utils import _ignore_deprecated
@@ -1393,6 +1393,49 @@ def testbench():
13931393
sim.add_testbench(testbench)
13941394

13951395

1396+
class SimulatorTracesTestCase(FHDLTestCase):
1397+
def assertDef(self, traces, flat_traces):
1398+
frag = Fragment()
1399+
1400+
def process():
1401+
yield Delay(1e-6)
1402+
1403+
sim = Simulator(frag)
1404+
sim.add_testbench(process)
1405+
with sim.write_vcd("test.vcd", "test.gtkw", traces=traces):
1406+
sim.run()
1407+
1408+
def test_signal(self):
1409+
a = Signal()
1410+
self.assertDef(a, [a])
1411+
1412+
def test_list(self):
1413+
a = Signal()
1414+
self.assertDef([a], [a])
1415+
1416+
def test_tuple(self):
1417+
a = Signal()
1418+
self.assertDef((a,), [a])
1419+
1420+
def test_dict(self):
1421+
a = Signal()
1422+
self.assertDef({"a": a}, [a])
1423+
1424+
def test_struct_view(self):
1425+
a = Signal(data.StructLayout({"a": 1, "b": 3}))
1426+
self.assertDef(a, [a])
1427+
1428+
def test_interface(self):
1429+
sig = wiring.Signature({
1430+
"a": wiring.In(1),
1431+
"b": wiring.Out(3),
1432+
"c": wiring.Out(2).array(4),
1433+
"d": wiring.In(wiring.Signature({"e": wiring.In(5)}))
1434+
})
1435+
a = sig.create()
1436+
self.assertDef(a, [a])
1437+
1438+
13961439
class SimulatorRegressionTestCase(FHDLTestCase):
13971440
def test_bug_325(self):
13981441
dut = Module()

0 commit comments

Comments
 (0)