Skip to content

hdl._mem: implement MemoryData._Row from RFC 62. #1271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 95 additions & 11 deletions amaranth/hdl/_dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,90 @@
from ._ir import *
from ._cd import *
from ._xfrm import *
from ._mem import MemoryData


__all__ = ["SyntaxError", "SyntaxWarning", "Module"]


class _Visitor:
def __init__(self):
self.driven_signals = SignalSet()

def visit_stmt(self, stmt):
if isinstance(stmt, _StatementList):
for s in stmt:
self.visit_stmt(s)
elif isinstance(stmt, Assign):
self.visit_lhs(stmt.lhs)
self.visit_rhs(stmt.rhs)
elif isinstance(stmt, Print):
for chunk in stmt.message._chunks:
if not isinstance(chunk, str):
obj, format_spec = chunk
self.visit_rhs(obj)
elif isinstance(stmt, Property):
self.visit_rhs(stmt.test)
if stmt.message is not None:
for chunk in stmt.message._chunks:
if not isinstance(chunk, str):
obj, format_spec = chunk
self.visit_rhs(obj)
elif isinstance(stmt, Switch):
self.visit_rhs(stmt.test)
for _patterns, stmts, _src_loc in stmt.cases:
self.visit_stmt(stmts)
elif isinstance(stmt, _LateBoundStatement):
pass
else:
assert False # :nocov:

def visit_lhs(self, value):
if isinstance(value, Operator) and value.operator in ("u", "s"):
self.visit_lhs(value.operands[0])
elif isinstance(value, (Signal, ClockSignal, ResetSignal)):
self.driven_signals.add(value)
elif isinstance(value, Slice):
self.visit_lhs(value.value)
elif isinstance(value, Part):
self.visit_lhs(value.value)
self.visit_rhs(value.offset)
elif isinstance(value, Concat):
for part in value.parts:
self.visit_lhs(part)
elif isinstance(value, SwitchValue):
self.visit_rhs(value.test)
for _patterns, elem in value.cases:
self.visit_lhs(elem)
elif isinstance(value, MemoryData._Row):
raise ValueError(f"Value {value!r} can only be used in simulator processes")
else:
raise ValueError(f"Value {value!r} cannot be assigned to")

def visit_rhs(self, value):
if isinstance(value, (Const, Signal, ClockSignal, ResetSignal, Initial, AnyValue)):
pass
elif isinstance(value, Operator):
for op in value.operands:
self.visit_rhs(op)
elif isinstance(value, Slice):
self.visit_rhs(value.value)
elif isinstance(value, Part):
self.visit_rhs(value.value)
self.visit_rhs(value.offset)
elif isinstance(value, Concat):
for part in value.parts:
self.visit_rhs(part)
elif isinstance(value, SwitchValue):
self.visit_rhs(value.test)
for _patterns, elem in value.cases:
self.visit_rhs(elem)
elif isinstance(value, MemoryData._Row):
raise ValueError(f"Value {value!r} can only be used in simulator processes")
else:
assert False # :nocov:


class _ModuleBuilderProxy:
def __init__(self, builder, depth):
object.__setattr__(self, "_builder", builder)
Expand Down Expand Up @@ -545,15 +624,16 @@ def _add_statement(self, assigns, domain, depth):

stmt._MustUse__used = True

if isinstance(stmt, Assign):
for signal in stmt._lhs_signals():
if signal not in self._driving:
self._driving[signal] = domain
elif self._driving[signal] != domain:
cd_curr = self._driving[signal]
raise SyntaxError(
f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is "
f"already driven from d.{cd_curr}")
visitor = _Visitor()
visitor.visit_stmt(stmt)
for signal in visitor.driven_signals:
if signal not in self._driving:
self._driving[signal] = domain
elif self._driving[signal] != domain:
cd_curr = self._driving[signal]
raise SyntaxError(
f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is "
f"already driven from d.{cd_curr}")

self._statements.setdefault(domain, []).append(stmt)

Expand Down Expand Up @@ -595,10 +675,14 @@ def elaborate(self, platform):
for domain, statements in self._statements.items():
statements = resolve_statements(statements)
fragment.add_statements(domain, statements)
for signal in statements._lhs_signals():
visitor = _Visitor()
visitor.visit_stmt(statements)
for signal in visitor.driven_signals:
fragment.add_driver(signal, domain)
fragment.add_statements("comb", self._top_comb_statements)
for signal in self._top_comb_statements._lhs_signals():
visitor = _Visitor()
visitor.visit_stmt(self._top_comb_statements)
for signal in visitor.driven_signals:
fragment.add_driver(signal, "comb")
fragment.add_domains(self._domains.values())
fragment.generated.update(self._generated)
Expand Down
53 changes: 31 additions & 22 deletions amaranth/hdl/_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,28 @@ def __repr__(self):
return f"MemoryData.Init({self._elems!r}, shape={self._shape!r}, depth={self._depth})"


@final
class _Row(Value):
def __init__(self, memory, index, *, src_loc_at=0):
assert isinstance(memory, MemoryData)
self._memory = memory
self._index = operator.index(index)
assert self._index in range(memory.depth)
super().__init__(src_loc_at=src_loc_at)

def shape(self):
return Shape.cast(self._memory.shape)

def _lhs_signals(self):
# This value cannot ever appear in a design.
raise NotImplementedError # :nocov:

_rhs_signals = _lhs_signals

def __repr__(self):
return f"(memory-row {self._memory!r} {self._index})"


def __init__(self, *, shape, depth, init, src_loc_at=0):
# shape and depth validation is performed in MemoryData.Init()
self._shape = shape
Expand Down Expand Up @@ -137,26 +159,14 @@ def __repr__(self):
return f"(memory-data {self.name})"

def __getitem__(self, index):
"""Simulation only."""
return MemorySimRead(self, index)


class MemorySimRead:
def __init__(self, memory, addr):
assert isinstance(memory, MemoryData)
self._memory = memory
self._addr = Value.cast(addr)

def eq(self, value):
return MemorySimWrite(self._memory, self._addr, value)


class MemorySimWrite:
def __init__(self, memory, addr, data):
assert isinstance(memory, MemoryData)
self._memory = memory
self._addr = Value.cast(addr)
self._data = Value.cast(data)
index = operator.index(index)
if index not in range(self.depth):
raise IndexError(f"Index {index} is out of bounds (memory has {self.depth} rows)")
row = MemoryData._Row(self, index)
if isinstance(self.shape, ShapeCastable):
return self.shape(row)
else:
return row


class MemoryInstance(Fragment):
Expand Down Expand Up @@ -312,8 +322,7 @@ def write_port(self, *, src_loc_at=0, **kwargs):
return WritePort(self, src_loc_at=1 + src_loc_at, **kwargs)

def __getitem__(self, index):
"""Simulation only."""
return MemorySimRead(self._data, index)
return self._data[index]

def elaborate(self, platform):
f = MemoryInstance(data=self._data, attrs=self.attrs, src_loc=self.src_loc)
Expand Down
6 changes: 1 addition & 5 deletions amaranth/lib/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import MutableSequence

from ..hdl import MemoryData, MemoryInstance, Shape, ShapeCastable, Const
from ..hdl._mem import MemorySimRead, FrozenError
from ..hdl._mem import FrozenError
from ..utils import ceil_log2
from .._utils import final
from .. import tracer
Expand Down Expand Up @@ -194,10 +194,6 @@ def elaborate(self, platform):
transparent_for=transparent_for)
return instance

def __getitem__(self, index):
"""Simulation only."""
return self._data[index]


class ReadPort:
"""A read memory port.
Expand Down
2 changes: 1 addition & 1 deletion amaranth/sim/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class BaseMemoryState:
def read(self, addr):
raise NotImplementedError # :nocov:

def write(self, addr, value):
def write(self, addr, value, mask=None):
raise NotImplementedError # :nocov:


Expand Down
18 changes: 0 additions & 18 deletions amaranth/sim/_pycoro.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from ..hdl import *
from ..hdl._ast import Statement, Assign, SignalSet, ValueCastable
from ..hdl._mem import MemorySimRead, MemorySimWrite
from .core import Tick, Settle, Delay, Passive, Active
from ._base import BaseProcess, BaseMemoryState
from ._pyeval import eval_value, eval_assign
Expand Down Expand Up @@ -123,23 +122,6 @@ def run(self):
elif type(command) is Active:
self.passive = False

elif type(command) is MemorySimRead:
addr = eval_value(self.state, command._addr)
index = self.state.get_memory(command._memory)
state = self.state.slots[index]
assert isinstance(state, BaseMemoryState)
response = state.read(addr)

elif type(command) is MemorySimWrite:
addr = eval_value(self.state, command._addr)
data = eval_value(self.state, command._data)
index = self.state.get_memory(command._memory)
state = self.state.slots[index]
assert isinstance(state, BaseMemoryState)
state.write(addr, data)
if self.testbench:
return True # assignment; run a delta cycle

elif command is None: # only possible if self.default_cmd is None
raise TypeError("Received default command from process {!r} that was added "
"with add_process(); did you mean to use Tick() instead?"
Expand Down
13 changes: 13 additions & 0 deletions amaranth/sim/_pyeval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from amaranth.hdl._ast import *
from amaranth.hdl._mem import MemoryData


def _eval_matches(test, patterns):
Expand Down Expand Up @@ -118,6 +119,9 @@ def eval_value(sim, value):
elif isinstance(value, Signal):
slot = sim.get_signal(value)
return sim.slots[slot].curr
elif isinstance(value, MemoryData._Row):
slot = sim.get_memory(value._memory)
return sim.slots[slot].read(value._index)
elif isinstance(value, (ResetSignal, ClockSignal, AnyValue, Initial)):
raise ValueError(f"Value {value!r} cannot be used in simulation")
else:
Expand All @@ -142,6 +146,15 @@ def _eval_assign_inner(sim, lhs, lhs_start, rhs, rhs_len):
if lhs._signed and (value & (1 << (len(lhs) - 1))):
value |= -1 << (len(lhs) - 1)
sim.slots[slot].set(value)
elif isinstance(lhs, MemoryData._Row):
lhs_stop = lhs_start + rhs_len
if lhs_stop > len(lhs):
lhs_stop = len(lhs)
if lhs_start >= len(lhs):
return
slot = sim.get_memory(lhs._memory)
mask = (1 << lhs_stop) - (1 << lhs_start)
sim.slots[slot].write(lhs._index, rhs << lhs_start, mask)
elif isinstance(lhs, Slice):
_eval_assign_inner(sim, lhs.value, lhs_start + lhs.start, rhs, rhs_len)
elif isinstance(lhs, Concat):
Expand Down
3 changes: 3 additions & 0 deletions amaranth/sim/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ..hdl._cd import *
from ..hdl._ir import *
from ..hdl._ast import Value, ValueLike
from ..hdl._mem import MemoryData
from ._base import BaseEngine


Expand Down Expand Up @@ -242,6 +243,8 @@ def write_vcd(self, vcd_file, gtkw_file=None, *, traces=(), fs_per_delta=0):
for trace in traces:
if isinstance(trace, ValueLike):
trace_cast = Value.cast(trace)
if isinstance(trace_cast, MemoryData._Row):
continue
for trace_signal in trace_cast._rhs_signals():
if trace_signal.name == "":
if trace_signal is trace:
Expand Down
39 changes: 27 additions & 12 deletions amaranth/sim/pysim.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,28 @@ def __init__(self, design, *, vcd_file, gtkw_file=None, traces=(), fs_per_delta=
for trace in traces:
if isinstance(trace, ValueLike):
trace = Value.cast(trace)
for trace_signal in trace._rhs_signals():
if trace_signal not in signal_names:
if trace_signal.name not in assigned_names:
name = trace_signal.name
if isinstance(trace, MemoryData._Row):
memory = trace._memory
if not memory in memories:
if memory.name not in assigned_names:
name = memory.name
else:
name = f"{trace_signal.name}${len(assigned_names)}"
name = f"{memory.name}${len(assigned_names)}"
assert name not in assigned_names
trace_names[trace_signal] = {("bench", name)}
memories[memory] = ("bench", name)
assigned_names.add(name)
self.traces.append(trace_signal)
self.traces.append(trace)
else:
for trace_signal in trace._rhs_signals():
if trace_signal not in signal_names:
if trace_signal.name not in assigned_names:
name = trace_signal.name
else:
name = f"{trace_signal.name}${len(assigned_names)}"
assert name not in assigned_names
trace_names[trace_signal] = {("bench", name)}
assigned_names.add(name)
self.traces.append(trace_signal)
elif isinstance(trace, MemoryData):
if not trace in memories:
if trace.name not in assigned_names:
Expand Down Expand Up @@ -223,13 +235,16 @@ def close(self, timestamp):
self.gtkw_save.dumpfile_size(self.vcd_file.tell())

self.gtkw_save.treeopen("top")
for signal in self.traces:
if isinstance(signal, Signal):
for name in self.gtkw_signal_names[signal]:
for trace in self.traces:
if isinstance(trace, Signal):
for name in self.gtkw_signal_names[trace]:
self.gtkw_save.trace(name)
elif isinstance(signal, MemoryIdentity):
for name in self.gtkw_memory_names[signal]:
elif isinstance(trace, MemoryData):
for name in self.gtkw_memory_names[trace]:
self.gtkw_save.trace(name)
elif isinstance(trace, MemoryData._Row):
name = self.gtkw_memory_names[trace._memory][trace._index]
self.gtkw_save.trace(name)
else:
assert False # :nocov:

Expand Down
1 change: 1 addition & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ Implemented RFCs
* `RFC 51`_: Add ``ShapeCastable.from_bits`` and ``amaranth.lib.data.Const``
* `RFC 53`_: Low-level I/O primitives
* `RFC 59`_: Get rid of upwards propagation of clock domains
* `RFC 62`_: The `MemoryData`` class


Language changes
Expand Down
Loading