hdl.ast: add Past, Stable, Rose, Fell.

This commit is contained in:
whitequark 2019-01-17 04:31:27 +00:00
parent 16f90d3585
commit 8c96675580
6 changed files with 112 additions and 9 deletions

View file

@ -10,7 +10,8 @@ from ..tools import *
__all__ = [
"Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl",
"Array", "ArrayProxy", "Sample",
"Array", "ArrayProxy",
"Sample", "Past", "Stable", "Rose", "Fell",
"Signal", "ClockSignal", "ResetSignal",
"Statement", "Assign", "Assert", "Assume", "Switch", "Delay", "Tick",
"Passive", "ValueKey", "ValueDict", "ValueSet", "SignalKey", "SignalDict",
@ -471,10 +472,10 @@ class Cat(Value):
return sum(len(part) for part in self.parts), False
def _lhs_signals(self):
return union(part._lhs_signals() for part in self.parts)
return union((part._lhs_signals() for part in self.parts), start=ValueSet())
def _rhs_signals(self):
return union(part._rhs_signals() for part in self.parts)
return union((part._rhs_signals() for part in self.parts), start=ValueSet())
def _as_const(self):
value = 0
@ -832,9 +833,15 @@ class ArrayProxy(Value):
class Sample(Value):
def __init__(self, value, clocks, domain):
"""Value from the past.
A ``Sample`` of an expression is equal to the value of the expression ``clocks`` clock edges
of the ``domain`` clock back. If that moment is before the beginning of time, it is equal
to the value of the expression calculated as if each signal had its reset value.
"""
def __init__(self, expr, clocks, domain):
super().__init__(src_loc_at=1)
self.value = Value.wrap(value)
self.value = Value.wrap(expr)
self.clocks = int(clocks)
self.domain = domain
if not isinstance(self.value, (Const, Signal)):
@ -855,6 +862,22 @@ class Sample(Value):
self.value, "<default>" if self.domain is None else self.domain, self.clocks)
def Past(expr, clocks=1, domain=None):
return Sample(expr, clocks, domain)
def Stable(expr, clocks=0, domain=None):
return Sample(expr, clocks + 1, domain) == Sample(expr, clocks, domain)
def Rose(expr, clocks=0, domain=None):
return ~Sample(expr, clocks + 1, domain) & Sample(expr, clocks, domain)
def Fell(expr, clocks=0, domain=None):
return Sample(expr, clocks + 1, domain) & ~Sample(expr, clocks, domain)
class _StatementList(list):
def __repr__(self):
return "({})".format(" ".join(map(repr, self)))

View file

@ -353,6 +353,7 @@ class Module(_ModuleBuilderRoot):
"Only assignments, asserts, and assumes may be appended to d.{}"
.format(domain_name(domain)))
assign = SampleDomainInjector(domain)(assign)
for signal in assign._lhs_signals():
if signal not in self._driving:
self._driving[signal] = domain
@ -384,7 +385,8 @@ class Module(_ModuleBuilderRoot):
fragment = Fragment()
for submodule, name in self._submodules:
fragment.add_subfragment(submodule.get_fragment(platform), name)
fragment.add_statements(self._statements)
statements = SampleDomainInjector("sync")(self._statements)
fragment.add_statements(statements)
for signal, domain in self._driving.items():
fragment.add_driver(signal, domain)
fragment.add_domains(self._domains)

View file

@ -363,9 +363,9 @@ class Fragment:
SignalSet(self.iter_ports("io")))
def prepare(self, ports=(), ensure_sync_exists=True):
from .xfrm import FragmentTransformer
from .xfrm import SampleLowerer
fragment = FragmentTransformer()(self)
fragment = SampleLowerer()(self)
fragment._propagate_domains(ensure_sync_exists)
fragment._resolve_hierarchy_conflicts()
fragment = fragment._insert_domain_resets()

View file

@ -13,7 +13,8 @@ from .rec import *
__all__ = ["ValueVisitor", "ValueTransformer",
"StatementVisitor", "StatementTransformer",
"FragmentTransformer",
"DomainRenamer", "DomainLowerer", "SampleLowerer",
"DomainRenamer", "DomainLowerer",
"SampleDomainInjector", "SampleLowerer",
"SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter",
"ResetInserter", "CEInserter"]
@ -340,6 +341,19 @@ class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer)
return cd.rst
class SampleDomainInjector(ValueTransformer, StatementTransformer):
def __init__(self, domain):
self.domain = domain
def on_Sample(self, value):
if value.domain is not None:
return value
return Sample(value.value, value.clocks, self.domain)
def __call__(self, stmts):
return self.on_statement(stmts)
class SampleLowerer(FragmentTransformer, ValueTransformer, StatementTransformer):
def __init__(self):
self.sample_cache = ValueDict()

View file

@ -113,6 +113,24 @@ class DSLTestCase(FHDLTestCase):
)
""")
def test_sample_domain(self):
m = Module()
i = Signal()
o1 = Signal()
o2 = Signal()
o3 = Signal()
m.d.sync += o1.eq(Past(i))
m.d.pix += o2.eq(Past(i))
m.d.pix += o3.eq(Past(i, domain="sync"))
f = m.lower(platform=None)
self.assertRepr(f.statements, """
(
(eq (sig o1) (sample (sig i) @ sync[1]))
(eq (sig o2) (sample (sig i) @ pix[1]))
(eq (sig o3) (sample (sig i) @ sync[1]))
)
""")
def test_If(self):
m = Module()
with m.If(self.s1):

View file

@ -545,6 +545,52 @@ class SimulatorIntegrationTestCase(FHDLTestCase):
sim.add_clock(1e-6)
sim.add_sync_process(process)
def test_sample_helpers(self):
m = Module()
s = Signal(2)
def mk(x):
y = Signal.like(x)
m.d.comb += y.eq(x)
return y
p0, r0, f0, s0 = mk(Past(s, 0)), mk(Rose(s)), mk(Fell(s)), mk(Stable(s))
p1, r1, f1, s1 = mk(Past(s)), mk(Rose(s, 1)), mk(Fell(s, 1)), mk(Stable(s, 1))
p2, r2, f2, s2 = mk(Past(s, 2)), mk(Rose(s, 2)), mk(Fell(s, 2)), mk(Stable(s, 2))
p3, r3, f3, s3 = mk(Past(s, 3)), mk(Rose(s, 3)), mk(Fell(s, 3)), mk(Stable(s, 3))
with self.assertSimulation(m) as sim:
def process_gen():
yield s.eq(0b10)
yield
yield
yield s.eq(0b01)
yield
def process_check():
yield
yield
yield
self.assertEqual((yield p0), 0b01)
self.assertEqual((yield p1), 0b10)
self.assertEqual((yield p2), 0b10)
self.assertEqual((yield p3), 0b00)
self.assertEqual((yield s0), 0b0)
self.assertEqual((yield s1), 0b1)
self.assertEqual((yield s2), 0b0)
self.assertEqual((yield s3), 0b1)
self.assertEqual((yield r0), 0b01)
self.assertEqual((yield r1), 0b00)
self.assertEqual((yield r2), 0b10)
self.assertEqual((yield r3), 0b00)
self.assertEqual((yield f0), 0b10)
self.assertEqual((yield f1), 0b00)
self.assertEqual((yield f2), 0b00)
self.assertEqual((yield f3), 0b00)
sim.add_clock(1e-6)
sim.add_sync_process(process_gen)
sim.add_sync_process(process_check)
def test_wrong_not_run(self):
with self.assertWarns(UserWarning,
msg="Simulation created, but not run"):