Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for monitor with product type arguments #299

Merged
merged 6 commits into from
Feb 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ vnc_logs
coverage.xml
parser.out
parsetab.py
.ast_tools
2 changes: 1 addition & 1 deletion fault/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
from fault.sva import sva
from fault.assert_immediate import assert_immediate, assert_final
from fault.expression import abs, min, max, signed, integer
from fault.pysv import PysvMonitor
from fault.pysv import PysvMonitor, python_monitor
110 changes: 110 additions & 0 deletions fault/pysv.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,117 @@
from abc import abstractmethod, ABC, ABCMeta
import inspect
from pysv import sv

from ast_tools.stack import SymbolTable
from ast_tools.passes import Pass, PASS_ARGS_T, apply_passes
from ast_tools.common import to_module
import typing as tp

import libcst as cst
import libcst.matchers as match

import magma as m


def _gen_product_args(base_name, T):
"""
Returns:
* Flat arguments used for the observe function
* Nested arguments used to reconstruct the original argument using
SimpleNamespace for dot notation
"""
if not issubclass(T, m.Product):
return [base_name], base_name
flat_args = []
nested_args = []
for elem, value in T.field_dict.items():
_flat_args, _nested_arg = _gen_product_args(
base_name + "_" + elem, value
)
flat_args += _flat_args
nested_args.append(f"{elem}={_nested_arg}")
return flat_args, f"SimpleNamespace({', '.join(nested_args)})"


class PysvMonitor(ABC, metaclass=ABCMeta):
@abstractmethod
def observe(self, *args, **kwargs):
pass


class MonitorTransformer(cst.CSTTransformer):
def __init__(self, env, metadata):
self.env = env
self.metadata = metadata

def leave_FunctionDef(self, original_node, updated_node):
"""
Transforms the observe function to accept the flattened version of
arguments (e.g. product leaf ports). Inserts code into the beginning
of the function to reconstruct the original argument using
SimpleNamespace to provide dot notation
"""
if match.matches(updated_node.name,
match.Name("observe")):
params = updated_node.params.params
assert match.matches(params[0], match.Param(match.Name("self")))
# Keep around the original arguments since this will be needed by
# the monitor peek logic (requires the unflattened arguments
self.metadata["_orig_observe_args_"] = [param.name.value
for param in params]
new_params = [params[0]]
prelude = [
cst.parse_statement("from types import SimpleNamespace")
]
for param in params[1:]:
if param.annotation is None:
new_params.append(param)
else:
T = eval(to_module(param.annotation.annotation).code,
dict(self.env))
if not issubclass(T, m.Product):
raise NotImplementedError()
flat_args, nested_args = _gen_product_args(
param.name.value, T)
new_params.extend(
cst.Param(cst.Name(arg)) for arg in flat_args)
prelude.append(
cst.parse_statement(
f"{param.name.value} = {nested_args}"))
return updated_node.with_changes(
params=updated_node.params.with_changes(params=new_params),
body=updated_node.body.with_changes(
body=tuple(prelude) + updated_node.body.body)
)
return updated_node


class MonitorPass(Pass):
def rewrite(self,
tree: cst.CSTNode,
env: SymbolTable,
metadata: tp.MutableMapping) -> PASS_ARGS_T:
tree = tree.visit(MonitorTransformer(env, metadata))
return tree, env, metadata


class python_monitor(apply_passes):
def __init__(self, pre_passes=[], post_passes=[],
debug: bool = False,
env: tp.Optional[SymbolTable] = None,
path: tp.Optional[str] = None,
file_name: tp.Optional[str] = None
):
passes = pre_passes + [MonitorPass()] + post_passes
super().__init__(passes=passes, env=env, debug=debug, path=path,
file_name=file_name)

def exec(self,
etree: tp.Union[cst.ClassDef, cst.FunctionDef],
stree: tp.Union[cst.ClassDef, cst.FunctionDef],
env: SymbolTable,
metadata: tp.MutableMapping):
result = super().exec(etree, stree, env, metadata)
result.observe._orig_args_ = metadata["_orig_observe_args_"]
result._source_code_ = to_module(stree).code
return result
2 changes: 1 addition & 1 deletion fault/system_verilog_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def process_peek(self, value):
self.disable_ndarray
)
return f"dut.{path}"
return f"{value.port.name}"
return f"{verilog_name(value.port.name, self.disable_ndarray)}"

def make_var(self, i, action):
if isinstance(action._type, AbstractBitVectorMeta):
Expand Down
31 changes: 19 additions & 12 deletions fault/tester/synchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from fault.pysv import PysvMonitor
import fault.actions as actions

import magma as m


@add_control_structures
class SynchronousTester(StagedTester):
Expand All @@ -19,20 +21,25 @@ def __init__(self, *args, **kwargs):
def eval(self):
raise TypeError("Cannot eval with synchronous tester")

def advance_cycle(self):
self.step(1)
def _flat_peek(self, value):
if isinstance(value, m.Product):
return sum((self._flat_peek(elem) for elem in value), [])
if (isinstance(value, m.Array) and
not issubclass(value.T, m.Digital)):
raise NotImplementedError()
return [self.peek(value)]

def _call_monitors(self):
for monitor in self.monitors:
argspec = inspect.getfullargspec(monitor.observe.func_def.func)
assert argspec.args[0] == "self", "Expected self as first arg"
args = [self.peek(getattr(self._circuit, arg))
for arg in argspec.args[1:]]
assert argspec.varargs is None, "Unsupported"
assert argspec.varkw is None, "Unsupported"
assert argspec.kwonlyargs == [], "Unsupported"
assert argspec.kwonlydefaults is None, "Unsupported"
assert argspec.annotations == {}, "Unsupported"
assert argspec.defaults is None, "Unsupported"
args = monitor.observe._orig_args_
assert args[0] == "self"
args = sum((self._flat_peek(getattr(self._circuit, arg))
for arg in args[1:]), [])
self.make_call_stmt(monitor.observe, *args)

def advance_cycle(self):
self.step(1)
self._call_monitors()
self.step(1)

def make_target(self, target, **kwargs):
Expand Down
42 changes: 42 additions & 0 deletions tests/test_pysv.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class DelayedDUT(m.Circuit):
io.O @= m.Register(m.Bits[4], has_enable=True)()(dut()(io.A, io.B),
CE=io.CE)

@fault.python_monitor()
class Monitor(fault.PysvMonitor):
@sv()
def __init__(self):
Expand Down Expand Up @@ -142,5 +143,46 @@ def test(circuit, enable):
test(DelayedDUT, 0)


def test_monitor_product(target, simulator):
class T(m.Product):
A = m.In(m.Bits[4])
B = m.In(m.Bits[4])

class DelayedDUTProduct(m.Circuit):
io = m.IO(I=T, O=m.Out(m.Bits[4]))
io += m.ClockIO(has_enable=True)
io.O @= m.Register(m.Bits[4], has_enable=True)()(dut()(io.I.A, io.I.B),
CE=io.CE)

@fault.python_monitor()
class ProductMonitor(fault.PysvMonitor):
@sv()
def __init__(self):
self.value = None

@sv()
def observe(self, I: T, O):
if self.value is not None:
assert O == self.value, f"{O} != {self.value}"
self.value = BitVector[4](I.A) + BitVector[4](I.B)
print(f"next value {self.value}")

tester = fault.SynchronousTester(DelayedDUTProduct)
monitor = tester.Var("monitor", ProductMonitor)
# TODO: Need clock to start at 1 for proper semantics
tester.poke(DelayedDUTProduct.CLK, 1)
tester.poke(monitor, tester.make_call_expr(ProductMonitor))
tester.attach_monitor(monitor)
tester.poke(DelayedDUTProduct.CE, 1)

for i in range(4):
tester.poke(tester.circuit.I, (BitVector.random(4),
BitVector.random(4)))
tester.advance_cycle()
tester.advance_cycle()

run_tester(tester, target, simulator)


if __name__ == "__main__":
test_class("verilator", None)