hdl.xfrm: implement LHSGroupAnalyzer.

This commit is contained in:
whitequark 2018-12-22 06:50:32 +00:00
parent 98a9744be4
commit ae0cb48fbb
2 changed files with 97 additions and 4 deletions

View file

@ -12,7 +12,9 @@ from .ir import *
__all__ = ["ValueVisitor", "ValueTransformer", __all__ = ["ValueVisitor", "ValueTransformer",
"StatementVisitor", "StatementTransformer", "StatementVisitor", "StatementTransformer",
"FragmentTransformer", "FragmentTransformer",
"DomainRenamer", "DomainLowerer", "ResetInserter", "CEInserter"] "DomainRenamer", "DomainLowerer",
"LHSGroupAnalyzer",
"ResetInserter", "CEInserter"]
class ValueVisitor(metaclass=ABCMeta): class ValueVisitor(metaclass=ABCMeta):
@ -134,7 +136,7 @@ class StatementVisitor(metaclass=ABCMeta):
pass # :nocov: pass # :nocov:
@abstractmethod @abstractmethod
def on_statements(self, stmt): def on_statements(self, stmts):
pass # :nocov: pass # :nocov:
def on_unknown_statement(self, stmt): def on_unknown_statement(self, stmt):
@ -166,8 +168,8 @@ class StatementTransformer(StatementVisitor):
cases = OrderedDict((k, self.on_statement(v)) for k, v in stmt.cases.items()) cases = OrderedDict((k, self.on_statement(v)) for k, v in stmt.cases.items())
return Switch(self.on_value(stmt.test), cases) return Switch(self.on_value(stmt.test), cases)
def on_statements(self, stmt): def on_statements(self, stmts):
return _StatementList(flatten(self.on_statement(stmt) for stmt in stmt)) return _StatementList(flatten(self.on_statement(stmt) for stmt in stmts))
class FragmentTransformer: class FragmentTransformer:
@ -278,6 +280,51 @@ class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer)
return cd.rst return cd.rst
class LHSGroupAnalyzer(StatementVisitor):
def __init__(self):
self.signals = SignalDict()
self.unions = OrderedDict()
def find(self, signal):
if signal not in self.signals:
self.signals[signal] = len(self.signals)
group = self.signals[signal]
while group in self.unions:
group = self.unions[group]
self.signals[signal] = group
return group
def unify(self, root, *leaves):
root_group = self.find(root)
for leaf in leaves:
leaf_group = self.find(leaf)
self.unions[leaf_group] = root_group
def groups(self):
groups = OrderedDict()
for signal in self.signals:
group = self.find(signal)
if group not in groups:
groups[group] = SignalSet()
groups[group].add(signal)
return groups
def on_Assign(self, stmt):
self.unify(*stmt._lhs_signals())
def on_Switch(self, stmt):
for case_stmts in stmt.cases.values():
self.on_statements(case_stmts)
def on_statements(self, stmts):
for stmt in stmts:
self.on_statement(stmt)
def __call__(self, stmts):
self.on_statements(stmts)
return self.groups()
class _ControlInserter(FragmentTransformer): class _ControlInserter(FragmentTransformer):
def __init__(self, controls): def __init__(self, controls):
if isinstance(controls, Value): if isinstance(controls, Value):

View file

@ -158,6 +158,52 @@ class DomainLowererTestCase(FHDLTestCase):
DomainLowerer({"sync": sync})(f) DomainLowerer({"sync": sync})(f)
class LHSGroupAnalyzerTestCase(FHDLTestCase):
def test_no_group_unrelated(self):
a = Signal()
b = Signal()
stmts = [
a.eq(0),
b.eq(0),
]
groups = LHSGroupAnalyzer()(stmts)
self.assertEqual(list(groups.values()), [
SignalSet((a,)),
SignalSet((b,)),
])
def test_group_related(self):
a = Signal()
b = Signal()
stmts = [
a.eq(0),
Cat(a, b).eq(0),
]
groups = LHSGroupAnalyzer()(stmts)
self.assertEqual(list(groups.values()), [
SignalSet((a, b)),
])
def test_switch(self):
a = Signal()
b = Signal()
stmts = [
a.eq(0),
Switch(a, {
1: b.eq(0),
})
]
groups = LHSGroupAnalyzer()(stmts)
self.assertEqual(list(groups.values()), [
SignalSet((a,)),
SignalSet((b,)),
])
class ResetInserterTestCase(FHDLTestCase): class ResetInserterTestCase(FHDLTestCase):
def setUp(self): def setUp(self):
self.s1 = Signal() self.s1 = Signal()