hdl.xfrm: implement LHSGroupAnalyzer.
This commit is contained in:
parent
98a9744be4
commit
ae0cb48fbb
|
@ -12,7 +12,9 @@ from .ir import *
|
||||||
__all__ = ["ValueVisitor", "ValueTransformer",
|
__all__ = ["ValueVisitor", "ValueTransformer",
|
||||||
"StatementVisitor", "StatementTransformer",
|
"StatementVisitor", "StatementTransformer",
|
||||||
"FragmentTransformer",
|
"FragmentTransformer",
|
||||||
"DomainRenamer", "DomainLowerer", "ResetInserter", "CEInserter"]
|
"DomainRenamer", "DomainLowerer",
|
||||||
|
"LHSGroupAnalyzer",
|
||||||
|
"ResetInserter", "CEInserter"]
|
||||||
|
|
||||||
|
|
||||||
class ValueVisitor(metaclass=ABCMeta):
|
class ValueVisitor(metaclass=ABCMeta):
|
||||||
|
@ -134,7 +136,7 @@ class StatementVisitor(metaclass=ABCMeta):
|
||||||
pass # :nocov:
|
pass # :nocov:
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def on_statements(self, stmt):
|
def on_statements(self, stmts):
|
||||||
pass # :nocov:
|
pass # :nocov:
|
||||||
|
|
||||||
def on_unknown_statement(self, stmt):
|
def on_unknown_statement(self, stmt):
|
||||||
|
@ -166,8 +168,8 @@ class StatementTransformer(StatementVisitor):
|
||||||
cases = OrderedDict((k, self.on_statement(v)) for k, v in stmt.cases.items())
|
cases = OrderedDict((k, self.on_statement(v)) for k, v in stmt.cases.items())
|
||||||
return Switch(self.on_value(stmt.test), cases)
|
return Switch(self.on_value(stmt.test), cases)
|
||||||
|
|
||||||
def on_statements(self, stmt):
|
def on_statements(self, stmts):
|
||||||
return _StatementList(flatten(self.on_statement(stmt) for stmt in stmt))
|
return _StatementList(flatten(self.on_statement(stmt) for stmt in stmts))
|
||||||
|
|
||||||
|
|
||||||
class FragmentTransformer:
|
class FragmentTransformer:
|
||||||
|
@ -278,6 +280,51 @@ class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer)
|
||||||
return cd.rst
|
return cd.rst
|
||||||
|
|
||||||
|
|
||||||
|
class LHSGroupAnalyzer(StatementVisitor):
|
||||||
|
def __init__(self):
|
||||||
|
self.signals = SignalDict()
|
||||||
|
self.unions = OrderedDict()
|
||||||
|
|
||||||
|
def find(self, signal):
|
||||||
|
if signal not in self.signals:
|
||||||
|
self.signals[signal] = len(self.signals)
|
||||||
|
group = self.signals[signal]
|
||||||
|
while group in self.unions:
|
||||||
|
group = self.unions[group]
|
||||||
|
self.signals[signal] = group
|
||||||
|
return group
|
||||||
|
|
||||||
|
def unify(self, root, *leaves):
|
||||||
|
root_group = self.find(root)
|
||||||
|
for leaf in leaves:
|
||||||
|
leaf_group = self.find(leaf)
|
||||||
|
self.unions[leaf_group] = root_group
|
||||||
|
|
||||||
|
def groups(self):
|
||||||
|
groups = OrderedDict()
|
||||||
|
for signal in self.signals:
|
||||||
|
group = self.find(signal)
|
||||||
|
if group not in groups:
|
||||||
|
groups[group] = SignalSet()
|
||||||
|
groups[group].add(signal)
|
||||||
|
return groups
|
||||||
|
|
||||||
|
def on_Assign(self, stmt):
|
||||||
|
self.unify(*stmt._lhs_signals())
|
||||||
|
|
||||||
|
def on_Switch(self, stmt):
|
||||||
|
for case_stmts in stmt.cases.values():
|
||||||
|
self.on_statements(case_stmts)
|
||||||
|
|
||||||
|
def on_statements(self, stmts):
|
||||||
|
for stmt in stmts:
|
||||||
|
self.on_statement(stmt)
|
||||||
|
|
||||||
|
def __call__(self, stmts):
|
||||||
|
self.on_statements(stmts)
|
||||||
|
return self.groups()
|
||||||
|
|
||||||
|
|
||||||
class _ControlInserter(FragmentTransformer):
|
class _ControlInserter(FragmentTransformer):
|
||||||
def __init__(self, controls):
|
def __init__(self, controls):
|
||||||
if isinstance(controls, Value):
|
if isinstance(controls, Value):
|
||||||
|
|
|
@ -158,6 +158,52 @@ class DomainLowererTestCase(FHDLTestCase):
|
||||||
DomainLowerer({"sync": sync})(f)
|
DomainLowerer({"sync": sync})(f)
|
||||||
|
|
||||||
|
|
||||||
|
class LHSGroupAnalyzerTestCase(FHDLTestCase):
|
||||||
|
def test_no_group_unrelated(self):
|
||||||
|
a = Signal()
|
||||||
|
b = Signal()
|
||||||
|
stmts = [
|
||||||
|
a.eq(0),
|
||||||
|
b.eq(0),
|
||||||
|
]
|
||||||
|
|
||||||
|
groups = LHSGroupAnalyzer()(stmts)
|
||||||
|
self.assertEqual(list(groups.values()), [
|
||||||
|
SignalSet((a,)),
|
||||||
|
SignalSet((b,)),
|
||||||
|
])
|
||||||
|
|
||||||
|
def test_group_related(self):
|
||||||
|
a = Signal()
|
||||||
|
b = Signal()
|
||||||
|
stmts = [
|
||||||
|
a.eq(0),
|
||||||
|
Cat(a, b).eq(0),
|
||||||
|
]
|
||||||
|
|
||||||
|
groups = LHSGroupAnalyzer()(stmts)
|
||||||
|
self.assertEqual(list(groups.values()), [
|
||||||
|
SignalSet((a, b)),
|
||||||
|
])
|
||||||
|
|
||||||
|
def test_switch(self):
|
||||||
|
a = Signal()
|
||||||
|
b = Signal()
|
||||||
|
stmts = [
|
||||||
|
a.eq(0),
|
||||||
|
Switch(a, {
|
||||||
|
1: b.eq(0),
|
||||||
|
})
|
||||||
|
]
|
||||||
|
|
||||||
|
groups = LHSGroupAnalyzer()(stmts)
|
||||||
|
self.assertEqual(list(groups.values()), [
|
||||||
|
SignalSet((a,)),
|
||||||
|
SignalSet((b,)),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ResetInserterTestCase(FHDLTestCase):
|
class ResetInserterTestCase(FHDLTestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.s1 = Signal()
|
self.s1 = Signal()
|
||||||
|
|
Loading…
Reference in a new issue