hdl.xfrm: add SampleLowerer.

This commit is contained in:
whitequark 2019-01-17 01:41:02 +00:00
parent b3de114d67
commit 198efcad31
4 changed files with 110 additions and 13 deletions

View file

@ -74,6 +74,15 @@ normalize = Const.normalize
class _ValueCompiler(ValueVisitor):
def on_AnyConst(self, value):
raise NotImplementedError # :nocov:
def on_AnySeq(self, value):
raise NotImplementedError # :nocov:
def on_Sample(self, value):
raise NotImplementedError # :nocov:
def on_Record(self, value):
return self(Cat(value.fields.values()))
@ -87,12 +96,6 @@ class _RHSValueCompiler(_ValueCompiler):
def on_Const(self, value):
return lambda state: value.value
def on_AnyConst(self, value):
raise NotImplementedError # :nocov:
def on_AnySeq(self, value):
raise NotImplementedError # :nocov:
def on_Signal(self, value):
if self.sensitivity is not None:
self.sensitivity.add(value)
@ -225,12 +228,6 @@ class _LHSValueCompiler(_ValueCompiler):
def on_Const(self, value):
raise TypeError # :nocov:
def on_AnyConst(self, value):
raise TypeError # :nocov:
def on_AnySeq(self, value):
raise TypeError # :nocov:
def on_Signal(self, value):
shape = value.shape()
value_slot = self.signal_slots[value]

View file

@ -308,6 +308,9 @@ class _ValueCompiler(xfrm.ValueVisitor):
def on_ResetSignal(self, value):
raise NotImplementedError # :nocov:
def on_Sample(self, value):
raise NotImplementedError # :nocov:
def on_Record(self, value):
return self(ast.Cat(value.fields.values()))

View file

@ -13,7 +13,7 @@ from .rec import *
__all__ = ["ValueVisitor", "ValueTransformer",
"StatementVisitor", "StatementTransformer",
"FragmentTransformer",
"DomainRenamer", "DomainLowerer",
"DomainRenamer", "DomainLowerer", "SampleLowerer",
"SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter",
"ResetInserter", "CEInserter"]
@ -71,6 +71,10 @@ class ValueVisitor(metaclass=ABCMeta):
def on_ArrayProxy(self, value):
pass # :nocov:
@abstractmethod
def on_Sample(self, value):
pass # :nocov:
def on_unknown_value(self, value):
raise TypeError("Cannot transform value '{!r}'".format(value)) # :nocov:
@ -102,6 +106,8 @@ class ValueVisitor(metaclass=ABCMeta):
new_value = self.on_Repl(value)
elif type(value) is ArrayProxy:
new_value = self.on_ArrayProxy(value)
elif type(value) is Sample:
new_value = self.on_Sample(value)
else:
new_value = self.on_unknown_value(value)
if isinstance(new_value, Value):
@ -153,6 +159,9 @@ class ValueTransformer(ValueVisitor):
return ArrayProxy([self.on_value(elem) for elem in value._iter_as_values()],
self.on_value(value.index))
def on_Sample(self, value):
return Sample(self.on_value(value.value), value.clocks, value.domain)
class StatementVisitor(metaclass=ABCMeta):
@abstractmethod
@ -331,6 +340,48 @@ class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer)
return cd.rst
class SampleLowerer(FragmentTransformer, ValueTransformer, StatementTransformer):
def __init__(self):
self.sample_cache = ValueDict()
self.sample_stmts = OrderedDict()
def _name_reset(self, value):
if isinstance(value, Const):
return "c${}".format(value.value), value.value
elif isinstance(value, Signal):
return "s${}".format(value.name), value.reset
else:
raise NotImplementedError # :nocov:
def on_Sample(self, value):
if value in self.sample_cache:
return self.sample_cache[value]
if value.clocks == 0:
sample = value.value
else:
assert value.domain is not None
sampled_name, sampled_reset = self._name_reset(value.value)
name = "$sample${}${}${}".format(sampled_name, value.domain, value.clocks)
sample = Signal.like(value.value, name=name, reset_less=True, reset=sampled_reset)
prev_sample = self.on_Sample(Sample(value.value, value.clocks - 1, value.domain))
if value.domain not in self.sample_stmts:
self.sample_stmts[value.domain] = []
self.sample_stmts[value.domain].append(sample.eq(prev_sample))
self.sample_cache[value] = sample
return sample
def on_fragment(self, fragment):
new_fragment = super().on_fragment(fragment)
for domain, stmts in self.sample_stmts.items():
new_fragment.add_statements(stmts)
for stmt in stmts:
new_fragment.add_driver(stmt.lhs, domain)
return new_fragment
class SwitchCleaner(StatementVisitor):
def on_Assign(self, stmt):
return stmt

View file

@ -170,6 +170,52 @@ class DomainLowererTestCase(FHDLTestCase):
DomainLowerer({"sync": sync})(f)
class SampleLowererTestCase(FHDLTestCase):
def setUp(self):
self.i = Signal()
self.o1 = Signal()
self.o2 = Signal()
self.o3 = Signal()
def test_lower_signal(self):
f = Fragment()
f.add_statements(
self.o1.eq(Sample(self.i, 2, "sync")),
self.o2.eq(Sample(self.i, 1, "sync")),
self.o3.eq(Sample(self.i, 1, "pix")),
)
f = SampleLowerer()(f)
self.assertRepr(f.statements, """
(
(eq (sig o1) (sig $sample$s$i$sync$2))
(eq (sig o2) (sig $sample$s$i$sync$1))
(eq (sig o3) (sig $sample$s$i$pix$1))
(eq (sig $sample$s$i$sync$1) (sig i))
(eq (sig $sample$s$i$sync$2) (sig $sample$s$i$sync$1))
(eq (sig $sample$s$i$pix$1) (sig i))
)
""")
self.assertEqual(len(f.drivers["sync"]), 2)
self.assertEqual(len(f.drivers["pix"]), 1)
def test_lower_const(self):
f = Fragment()
f.add_statements(
self.o1.eq(Sample(1, 2, "sync")),
)
f = SampleLowerer()(f)
self.assertRepr(f.statements, """
(
(eq (sig o1) (sig $sample$c$1$sync$2))
(eq (sig $sample$c$1$sync$1) (const 1'd1))
(eq (sig $sample$c$1$sync$2) (sig $sample$c$1$sync$1))
)
""")
self.assertEqual(len(f.drivers["sync"]), 2)
class SwitchCleanerTestCase(FHDLTestCase):
def test_clean(self):
a = Signal()