hdl.xfrm: add SampleLowerer.
This commit is contained in:
parent
b3de114d67
commit
198efcad31
|
@ -74,6 +74,15 @@ normalize = Const.normalize
|
|||
|
||||
|
||||
class _ValueCompiler(ValueVisitor):
|
||||
def on_AnyConst(self, value):
|
||||
raise NotImplementedError # :nocov:
|
||||
|
||||
def on_AnySeq(self, value):
|
||||
raise NotImplementedError # :nocov:
|
||||
|
||||
def on_Sample(self, value):
|
||||
raise NotImplementedError # :nocov:
|
||||
|
||||
def on_Record(self, value):
|
||||
return self(Cat(value.fields.values()))
|
||||
|
||||
|
@ -87,12 +96,6 @@ class _RHSValueCompiler(_ValueCompiler):
|
|||
def on_Const(self, value):
|
||||
return lambda state: value.value
|
||||
|
||||
def on_AnyConst(self, value):
|
||||
raise NotImplementedError # :nocov:
|
||||
|
||||
def on_AnySeq(self, value):
|
||||
raise NotImplementedError # :nocov:
|
||||
|
||||
def on_Signal(self, value):
|
||||
if self.sensitivity is not None:
|
||||
self.sensitivity.add(value)
|
||||
|
@ -225,12 +228,6 @@ class _LHSValueCompiler(_ValueCompiler):
|
|||
def on_Const(self, value):
|
||||
raise TypeError # :nocov:
|
||||
|
||||
def on_AnyConst(self, value):
|
||||
raise TypeError # :nocov:
|
||||
|
||||
def on_AnySeq(self, value):
|
||||
raise TypeError # :nocov:
|
||||
|
||||
def on_Signal(self, value):
|
||||
shape = value.shape()
|
||||
value_slot = self.signal_slots[value]
|
||||
|
|
|
@ -308,6 +308,9 @@ class _ValueCompiler(xfrm.ValueVisitor):
|
|||
def on_ResetSignal(self, value):
|
||||
raise NotImplementedError # :nocov:
|
||||
|
||||
def on_Sample(self, value):
|
||||
raise NotImplementedError # :nocov:
|
||||
|
||||
def on_Record(self, value):
|
||||
return self(ast.Cat(value.fields.values()))
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ from .rec import *
|
|||
__all__ = ["ValueVisitor", "ValueTransformer",
|
||||
"StatementVisitor", "StatementTransformer",
|
||||
"FragmentTransformer",
|
||||
"DomainRenamer", "DomainLowerer",
|
||||
"DomainRenamer", "DomainLowerer", "SampleLowerer",
|
||||
"SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter",
|
||||
"ResetInserter", "CEInserter"]
|
||||
|
||||
|
@ -71,6 +71,10 @@ class ValueVisitor(metaclass=ABCMeta):
|
|||
def on_ArrayProxy(self, value):
|
||||
pass # :nocov:
|
||||
|
||||
@abstractmethod
|
||||
def on_Sample(self, value):
|
||||
pass # :nocov:
|
||||
|
||||
def on_unknown_value(self, value):
|
||||
raise TypeError("Cannot transform value '{!r}'".format(value)) # :nocov:
|
||||
|
||||
|
@ -102,6 +106,8 @@ class ValueVisitor(metaclass=ABCMeta):
|
|||
new_value = self.on_Repl(value)
|
||||
elif type(value) is ArrayProxy:
|
||||
new_value = self.on_ArrayProxy(value)
|
||||
elif type(value) is Sample:
|
||||
new_value = self.on_Sample(value)
|
||||
else:
|
||||
new_value = self.on_unknown_value(value)
|
||||
if isinstance(new_value, Value):
|
||||
|
@ -153,6 +159,9 @@ class ValueTransformer(ValueVisitor):
|
|||
return ArrayProxy([self.on_value(elem) for elem in value._iter_as_values()],
|
||||
self.on_value(value.index))
|
||||
|
||||
def on_Sample(self, value):
|
||||
return Sample(self.on_value(value.value), value.clocks, value.domain)
|
||||
|
||||
|
||||
class StatementVisitor(metaclass=ABCMeta):
|
||||
@abstractmethod
|
||||
|
@ -331,6 +340,48 @@ class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer)
|
|||
return cd.rst
|
||||
|
||||
|
||||
class SampleLowerer(FragmentTransformer, ValueTransformer, StatementTransformer):
|
||||
def __init__(self):
|
||||
self.sample_cache = ValueDict()
|
||||
self.sample_stmts = OrderedDict()
|
||||
|
||||
def _name_reset(self, value):
|
||||
if isinstance(value, Const):
|
||||
return "c${}".format(value.value), value.value
|
||||
elif isinstance(value, Signal):
|
||||
return "s${}".format(value.name), value.reset
|
||||
else:
|
||||
raise NotImplementedError # :nocov:
|
||||
|
||||
def on_Sample(self, value):
|
||||
if value in self.sample_cache:
|
||||
return self.sample_cache[value]
|
||||
|
||||
if value.clocks == 0:
|
||||
sample = value.value
|
||||
else:
|
||||
assert value.domain is not None
|
||||
sampled_name, sampled_reset = self._name_reset(value.value)
|
||||
name = "$sample${}${}${}".format(sampled_name, value.domain, value.clocks)
|
||||
sample = Signal.like(value.value, name=name, reset_less=True, reset=sampled_reset)
|
||||
|
||||
prev_sample = self.on_Sample(Sample(value.value, value.clocks - 1, value.domain))
|
||||
if value.domain not in self.sample_stmts:
|
||||
self.sample_stmts[value.domain] = []
|
||||
self.sample_stmts[value.domain].append(sample.eq(prev_sample))
|
||||
|
||||
self.sample_cache[value] = sample
|
||||
return sample
|
||||
|
||||
def on_fragment(self, fragment):
|
||||
new_fragment = super().on_fragment(fragment)
|
||||
for domain, stmts in self.sample_stmts.items():
|
||||
new_fragment.add_statements(stmts)
|
||||
for stmt in stmts:
|
||||
new_fragment.add_driver(stmt.lhs, domain)
|
||||
return new_fragment
|
||||
|
||||
|
||||
class SwitchCleaner(StatementVisitor):
|
||||
def on_Assign(self, stmt):
|
||||
return stmt
|
||||
|
|
|
@ -170,6 +170,52 @@ class DomainLowererTestCase(FHDLTestCase):
|
|||
DomainLowerer({"sync": sync})(f)
|
||||
|
||||
|
||||
class SampleLowererTestCase(FHDLTestCase):
|
||||
def setUp(self):
|
||||
self.i = Signal()
|
||||
self.o1 = Signal()
|
||||
self.o2 = Signal()
|
||||
self.o3 = Signal()
|
||||
|
||||
def test_lower_signal(self):
|
||||
f = Fragment()
|
||||
f.add_statements(
|
||||
self.o1.eq(Sample(self.i, 2, "sync")),
|
||||
self.o2.eq(Sample(self.i, 1, "sync")),
|
||||
self.o3.eq(Sample(self.i, 1, "pix")),
|
||||
)
|
||||
|
||||
f = SampleLowerer()(f)
|
||||
self.assertRepr(f.statements, """
|
||||
(
|
||||
(eq (sig o1) (sig $sample$s$i$sync$2))
|
||||
(eq (sig o2) (sig $sample$s$i$sync$1))
|
||||
(eq (sig o3) (sig $sample$s$i$pix$1))
|
||||
(eq (sig $sample$s$i$sync$1) (sig i))
|
||||
(eq (sig $sample$s$i$sync$2) (sig $sample$s$i$sync$1))
|
||||
(eq (sig $sample$s$i$pix$1) (sig i))
|
||||
)
|
||||
""")
|
||||
self.assertEqual(len(f.drivers["sync"]), 2)
|
||||
self.assertEqual(len(f.drivers["pix"]), 1)
|
||||
|
||||
def test_lower_const(self):
|
||||
f = Fragment()
|
||||
f.add_statements(
|
||||
self.o1.eq(Sample(1, 2, "sync")),
|
||||
)
|
||||
|
||||
f = SampleLowerer()(f)
|
||||
self.assertRepr(f.statements, """
|
||||
(
|
||||
(eq (sig o1) (sig $sample$c$1$sync$2))
|
||||
(eq (sig $sample$c$1$sync$1) (const 1'd1))
|
||||
(eq (sig $sample$c$1$sync$2) (sig $sample$c$1$sync$1))
|
||||
)
|
||||
""")
|
||||
self.assertEqual(len(f.drivers["sync"]), 2)
|
||||
|
||||
|
||||
class SwitchCleanerTestCase(FHDLTestCase):
|
||||
def test_clean(self):
|
||||
a = Signal()
|
||||
|
|
Loading…
Reference in a new issue