
A property statement that is created but never added to a module is almost always a serious bug, since it can make formal verification pass when it should not. Therefore, emit a warning for unused properties, similar to UnusedElaboratable. Doing this for all statements is possible, but nMigen creates many temporary statements internally, and the extensive changes required to eliminate the resulting false positives are likely not worth the true positives gained. We can revisit this in the future. Fixes #303.
755 lines
25 KiB
Python
755 lines
25 KiB
Python
from abc import ABCMeta, abstractmethod
|
|
from collections import OrderedDict
|
|
from collections.abc import Iterable
|
|
|
|
from .._utils import flatten, deprecated
|
|
from .. import tracer
|
|
from .ast import *
|
|
from .ast import _StatementList
|
|
from .cd import *
|
|
from .ir import *
|
|
from .rec import *
|
|
|
|
|
|
# Public transformation and analysis passes exported by this module.
__all__ = [
    "ValueVisitor",
    "ValueTransformer",
    "StatementVisitor",
    "StatementTransformer",
    "FragmentTransformer",
    "TransformedElaboratable",
    "DomainCollector",
    "DomainRenamer",
    "DomainLowerer",
    "SampleDomainInjector",
    "SampleLowerer",
    "SwitchCleaner",
    "LHSGroupAnalyzer",
    "LHSGroupFilter",
    "ResetInserter",
    "EnableInserter",
]
|
|
|
|
|
|
class ValueVisitor(metaclass=ABCMeta):
    """Visitor over nMigen value trees.

    Subclasses provide one ``on_<Kind>`` handler for every value kind declared
    abstract below. :meth:`on_value` selects the handler for a value and, when the
    handler returns a :class:`Value`, copies the visited value's ``src_loc`` onto
    the result unless :meth:`replace_value_src_loc` returns false.
    """

    @abstractmethod
    def on_Const(self, value):
        pass  # :nocov:

    @abstractmethod
    def on_AnyConst(self, value):
        pass  # :nocov:

    @abstractmethod
    def on_AnySeq(self, value):
        pass  # :nocov:

    @abstractmethod
    def on_Signal(self, value):
        pass  # :nocov:

    @abstractmethod
    def on_Record(self, value):
        pass  # :nocov:

    @abstractmethod
    def on_ClockSignal(self, value):
        pass  # :nocov:

    @abstractmethod
    def on_ResetSignal(self, value):
        pass  # :nocov:

    @abstractmethod
    def on_Operator(self, value):
        pass  # :nocov:

    @abstractmethod
    def on_Slice(self, value):
        pass  # :nocov:

    @abstractmethod
    def on_Part(self, value):
        pass  # :nocov:

    @abstractmethod
    def on_Cat(self, value):
        pass  # :nocov:

    @abstractmethod
    def on_Repl(self, value):
        pass  # :nocov:

    @abstractmethod
    def on_ArrayProxy(self, value):
        pass  # :nocov:

    @abstractmethod
    def on_Sample(self, value):
        pass  # :nocov:

    @abstractmethod
    def on_Initial(self, value):
        pass  # :nocov:

    def on_unknown_value(self, value):
        """Handle a value of an unrecognized kind by raising ``TypeError``."""
        raise TypeError("Cannot transform value {!r}".format(value))  # :nocov:

    def replace_value_src_loc(self, value, new_value):
        """Return whether *new_value* should take over the ``src_loc`` of *value*.

        Always true here; subclasses override it to keep the result's own location.
        """
        return True

    def on_value(self, value):
        """Dispatch *value* to its ``on_<Kind>`` handler and return the handler's result."""
        kind = type(value)
        if kind is Const:
            result = self.on_Const(value)
        elif kind is AnyConst:
            result = self.on_AnyConst(value)
        elif kind is AnySeq:
            result = self.on_AnySeq(value)
        elif isinstance(value, Signal):
            # Subclass-aware check rather than an exact type match: nmigen.compat needs it.
            result = self.on_Signal(value)
        elif isinstance(value, Record):
            # Subclass-aware check so that classes derived from Record are accepted.
            result = self.on_Record(value)
        elif kind is ClockSignal:
            result = self.on_ClockSignal(value)
        elif kind is ResetSignal:
            result = self.on_ResetSignal(value)
        elif kind is Operator:
            result = self.on_Operator(value)
        elif kind is Slice:
            result = self.on_Slice(value)
        elif kind is Part:
            result = self.on_Part(value)
        elif kind is Cat:
            result = self.on_Cat(value)
        elif kind is Repl:
            result = self.on_Repl(value)
        elif kind is ArrayProxy:
            result = self.on_ArrayProxy(value)
        elif kind is Sample:
            result = self.on_Sample(value)
        elif kind is Initial:
            result = self.on_Initial(value)
        elif isinstance(value, UserValue):
            # Subclass-aware check; the user value is lowered first and the lowered
            # value is visited in its place.
            result = self.on_value(value._lazy_lower())
        else:
            result = self.on_unknown_value(value)

        if isinstance(result, Value) and self.replace_value_src_loc(value, result):
            result.src_loc = value.src_loc
        return result

    def __call__(self, value):
        """Shorthand for :meth:`on_value`."""
        return self.on_value(value)
|
|
|
|
|
|
class ValueTransformer(ValueVisitor):
    """Value visitor that returns a rebuilt copy of the visited value.

    Leaf values are returned unchanged and composite values are reconstructed from
    their transformed operands, so a subclass only overrides the kinds it alters.
    """

    def _identity(self, value):
        # Leaf kinds carry no sub-values to transform.
        return value

    on_Const = _identity
    on_AnyConst = _identity
    on_AnySeq = _identity
    on_Signal = _identity
    on_Record = _identity
    on_ClockSignal = _identity
    on_ResetSignal = _identity
    on_Initial = _identity

    def on_Operator(self, value):
        operands = [self.on_value(operand) for operand in value.operands]
        return Operator(value.operator, operands)

    def on_Slice(self, value):
        inner = self.on_value(value.value)
        return Slice(inner, value.start, value.stop)

    def on_Part(self, value):
        # The sliced value is transformed before the offset, matching operand order.
        inner = self.on_value(value.value)
        offset = self.on_value(value.offset)
        return Part(inner, offset, value.width, value.stride)

    def on_Cat(self, value):
        return Cat(self.on_value(part) for part in value.parts)

    def on_Repl(self, value):
        inner = self.on_value(value.value)
        return Repl(inner, value.count)

    def on_ArrayProxy(self, value):
        # All elements are transformed before the index.
        elems = [self.on_value(elem) for elem in value._iter_as_values()]
        index = self.on_value(value.index)
        return ArrayProxy(elems, index)

    def on_Sample(self, value):
        inner = self.on_value(value.value)
        return Sample(inner, value.clocks, value.domain)
|
|
|
|
|
|
class StatementVisitor(metaclass=ABCMeta):
    """Visitor over nMigen statements.

    Subclasses provide one ``on_<Kind>`` handler per statement kind plus
    :meth:`on_statements` for statement sequences. :meth:`on_statement` selects
    the handler and post-processes its result: source locations are carried over
    and property statements produced by the visitor are marked as used.
    """

    @abstractmethod
    def on_Assign(self, stmt):
        pass  # :nocov:

    @abstractmethod
    def on_Assert(self, stmt):
        pass  # :nocov:

    @abstractmethod
    def on_Assume(self, stmt):
        pass  # :nocov:

    @abstractmethod
    def on_Cover(self, stmt):
        pass  # :nocov:

    @abstractmethod
    def on_Switch(self, stmt):
        pass  # :nocov:

    @abstractmethod
    def on_statements(self, stmts):
        pass  # :nocov:

    def on_unknown_statement(self, stmt):
        """Handle a statement of an unrecognized kind by raising ``TypeError``."""
        raise TypeError("Cannot transform statement {!r}".format(stmt))  # :nocov:

    def replace_statement_src_loc(self, stmt, new_stmt):
        """Return whether *new_stmt* should take over the ``src_loc`` of *stmt*."""
        return True

    def on_statement(self, stmt):
        """Dispatch *stmt* to its handler and return the (post-processed) result."""
        kind = type(stmt)
        if kind is Assign:
            result = self.on_Assign(stmt)
        elif kind is Assert:
            result = self.on_Assert(stmt)
        elif kind is Assume:
            result = self.on_Assume(stmt)
        elif kind is Cover:
            result = self.on_Cover(stmt)
        elif isinstance(stmt, Switch):
            # Subclass-aware check rather than an exact type match: nmigen.compat needs it.
            result = self.on_Switch(stmt)
        elif isinstance(stmt, Iterable):
            # Any other iterable is treated as a sequence of statements.
            result = self.on_statements(stmt)
        else:
            result = self.on_unknown_statement(stmt)

        if isinstance(result, Statement) and self.replace_statement_src_loc(stmt, result):
            result.src_loc = stmt.src_loc
            if isinstance(stmt, Switch) and isinstance(result, Switch):
                # Carry over the per-case source locations too.
                result.case_src_locs = stmt.case_src_locs
        if isinstance(result, Property):
            # Properties rebuilt by a pass are internal copies of user statements; set the
            # `_MustUse__used` flag so they are not reported as unused properties.
            result._MustUse__used = True
        return result

    def __call__(self, stmt):
        """Shorthand for :meth:`on_statement`."""
        return self.on_statement(stmt)
|
|
|
|
|
|
class StatementTransformer(StatementVisitor):
    """Statement visitor that returns a rebuilt copy of every statement.

    Embedded values go through :meth:`on_value`, which is the identity here. Passes
    that list ``ValueTransformer`` earlier in their bases (e.g. DomainRenamer) pick up
    its ``on_value`` through the MRO instead, and so transform values as well.
    """

    def on_value(self, value):
        return value

    def on_Assign(self, stmt):
        lhs = self.on_value(stmt.lhs)
        rhs = self.on_value(stmt.rhs)
        return Assign(lhs, rhs)

    def on_Assert(self, stmt):
        test = self.on_value(stmt.test)
        return Assert(test, _check=stmt._check, _en=stmt._en)

    def on_Assume(self, stmt):
        test = self.on_value(stmt.test)
        return Assume(test, _check=stmt._check, _en=stmt._en)

    def on_Cover(self, stmt):
        test = self.on_value(stmt.test)
        return Cover(test, _check=stmt._check, _en=stmt._en)

    def on_Switch(self, stmt):
        # Case bodies are transformed before the switch test.
        new_cases = OrderedDict()
        for patterns, body in stmt.cases.items():
            new_cases[patterns] = self.on_statement(body)
        return Switch(self.on_value(stmt.test), new_cases)

    def on_statements(self, stmts):
        transformed = (self.on_statement(stmt) for stmt in stmts)
        return _StatementList(flatten(transformed))
|
|
|
|
|
|
class FragmentTransformer:
    """Base class for passes that rebuild a fragment hierarchy.

    :meth:`on_fragment` creates a fresh :class:`Fragment` (or :class:`Instance`) and
    fills it in through the ``map_*`` hooks; subclasses override individual hooks to
    change one aspect of the copy. Applying the transformer to an elaboratable defers
    the work until elaboration by wrapping it in a :class:`TransformedElaboratable`.
    """

    def map_subfragments(self, fragment, new_fragment):
        """Transform every subfragment recursively and attach it under the same name."""
        for child, child_name in fragment.subfragments:
            new_fragment.add_subfragment(self(child), child_name)

    def map_ports(self, fragment, new_fragment):
        """Copy the fragment's ports together with their directions."""
        for port, direction in fragment.ports.items():
            new_fragment.add_ports(port, dir=direction)

    def map_named_ports(self, fragment, new_fragment):
        """Copy an instance's named ports, transforming port values if this pass can."""
        if not hasattr(self, "on_value"):
            new_fragment.named_ports = OrderedDict(fragment.named_ports.items())
            return
        for port_name, (port_value, direction) in fragment.named_ports.items():
            new_fragment.named_ports[port_name] = self.on_value(port_value), direction

    def map_domains(self, fragment, new_fragment):
        """Copy the fragment's clock domains."""
        for domain_name in fragment.iter_domains():
            new_fragment.add_domains(fragment.domains[domain_name])

    def map_statements(self, fragment, new_fragment):
        """Copy the fragment's statements, transforming them if this pass can."""
        if not hasattr(self, "on_statement"):
            new_fragment.add_statements(fragment.statements)
            return
        new_fragment.add_statements(map(self.on_statement, fragment.statements))

    def map_drivers(self, fragment, new_fragment):
        """Copy the fragment's (domain, signal) driver pairs."""
        for domain, signal in fragment.iter_drivers():
            new_fragment.add_driver(signal, domain)

    def on_fragment(self, fragment):
        """Return a transformed copy of *fragment* built through the ``map_*`` hooks."""
        if isinstance(fragment, Instance):
            new_fragment = Instance(fragment.type)
            new_fragment.parameters = OrderedDict(fragment.parameters)
            self.map_named_ports(fragment, new_fragment)
        else:
            new_fragment = Fragment()
            new_fragment.flatten = fragment.flatten
        new_fragment.attrs = OrderedDict(fragment.attrs)
        for hook in (self.map_ports, self.map_subfragments, self.map_domains,
                     self.map_statements, self.map_drivers):
            hook(fragment, new_fragment)
        return new_fragment

    def __call__(self, value, *, src_loc_at=0):
        """Transform a fragment now, or schedule this transform on an elaboratable.

        Raises ``AttributeError`` if *value* is neither a fragment nor has ``elaborate``.
        """
        if isinstance(value, Fragment):
            return self.on_fragment(value)
        if not isinstance(value, TransformedElaboratable):
            if not hasattr(value, "elaborate"):
                raise AttributeError("Object {!r} cannot be elaborated".format(value))
            # Constructed directly here so `src_loc_at` keeps referring to our caller.
            value = TransformedElaboratable(value, src_loc_at=1 + src_loc_at)
        value._transforms_.append(self)
        return value
|
|
|
|
|
|
class TransformedElaboratable(Elaboratable):
    """Wraps an elaboratable and applies fragment transforms to its elaborated result.

    Attributes that are not found on the wrapper itself are looked up on the wrapped
    elaboratable, so the wrapper can stand in for it.
    """

    def __init__(self, elaboratable, *, src_loc_at=0):
        # NOTE(review): `src_loc_at` is not used in this body; presumably it is consumed by
        # `Elaboratable.__new__`, which receives the same constructor arguments -- verify.
        assert hasattr(elaboratable, "elaborate")

        # Fields prefixed and suffixed with underscore to avoid as many conflicts with the
        # inner object as possible, since we're forwarding attribute requests to it.
        self._elaboratable_ = elaboratable
        self._transforms_ = []

    def __getattr__(self, attr):
        # Python only calls `__getattr__` after normal lookup has failed. If `_elaboratable_`
        # itself is missing -- e.g. the instance was created without running `__init__`, as
        # `copy.copy()` and unpickling do -- forwarding would re-enter this method forever
        # and end in RecursionError. Report an ordinary missing attribute instead.
        if attr == "_elaboratable_":
            raise AttributeError("{!r} object has no attribute {!r}"
                                 .format(type(self).__name__, attr))
        return getattr(self._elaboratable_, attr)

    def elaborate(self, platform):
        """Elaborate the wrapped object, then apply each transform in the order added."""
        fragment = Fragment.get(self._elaboratable_, platform)
        for transform in self._transforms_:
            fragment = transform(fragment)
        return fragment
|
|
|
|
|
|
class DomainCollector(ValueVisitor, StatementVisitor):
    """Collects the clock domain names that a fragment hierarchy uses and defines.

    After the collector has been called on a fragment:

    * :attr:`used_domains` holds every domain referenced by clock/reset signals or by
      drivers, except ``None`` and domains declared local in the referencing fragment
      or one of its ancestors;
    * :attr:`defined_domains` holds every non-local domain declared anywhere in the
      hierarchy.
    """

    def __init__(self):
        self.used_domains = set()
        self.defined_domains = set()
        # Local domains in scope at the current point of the traversal; saved and restored
        # around each fragment so they only apply to that fragment and its children.
        self._local_domains = set()

    def _add_used_domain(self, domain_name):
        # NOTE(review): `None` presumably stands for the combinational pseudo-domain.
        if domain_name is not None and domain_name not in self._local_domains:
            self.used_domains.add(domain_name)

    def on_ignore(self, value):
        pass

    on_Const = on_ignore
    on_AnyConst = on_ignore
    on_AnySeq = on_ignore
    on_Signal = on_ignore
    on_Record = on_ignore
    on_Initial = on_ignore

    def on_ClockSignal(self, value):
        self._add_used_domain(value.domain)

    on_ResetSignal = on_ClockSignal

    def _visit_inner_value(self, value):
        # Slice, Repl and Sample each wrap a single inner value.
        self.on_value(value.value)

    on_Slice = _visit_inner_value
    on_Repl = _visit_inner_value
    on_Sample = _visit_inner_value

    def on_Operator(self, value):
        for operand in value.operands:
            self.on_value(operand)

    def on_Part(self, value):
        self.on_value(value.value)
        self.on_value(value.offset)

    def on_Cat(self, value):
        for part in value.parts:
            self.on_value(part)

    def on_ArrayProxy(self, value):
        for elem in value._iter_as_values():
            self.on_value(elem)
        self.on_value(value.index)

    def on_Assign(self, stmt):
        self.on_value(stmt.lhs)
        self.on_value(stmt.rhs)

    def on_property(self, stmt):
        self.on_value(stmt.test)

    on_Assert = on_property
    on_Assume = on_property
    on_Cover = on_property

    def on_Switch(self, stmt):
        self.on_value(stmt.test)
        for case_body in stmt.cases.values():
            self.on_statement(case_body)

    def on_statements(self, stmts):
        for each in stmts:
            self.on_statement(each)

    def on_fragment(self, fragment):
        """Collect domains from *fragment* and, recursively, from its subfragments."""
        if isinstance(fragment, Instance):
            for port_value, _direction in fragment.named_ports.values():
                self.on_value(port_value)

        # Children see the local domains of this fragment in addition to inherited ones.
        saved_local_domains = self._local_domains
        self._local_domains = set(saved_local_domains)
        for domain_name, domain in fragment.domains.items():
            target = self._local_domains if domain.local else self.defined_domains
            target.add(domain_name)

        self.on_statements(fragment.statements)
        for domain_name in fragment.drivers:
            self._add_used_domain(domain_name)
        for subfragment, _name in fragment.subfragments:
            self.on_fragment(subfragment)

        self._local_domains = saved_local_domains

    def __call__(self, fragment):
        self.on_fragment(fragment)
|
|
|
|
|
|
class DomainRenamer(FragmentTransformer, ValueTransformer, StatementTransformer):
    """Renames clock domains throughout a fragment hierarchy.

    *domain_map* maps old domain names to new ones; a plain string is shorthand for
    renaming ``"sync"`` to that string. Raises ``ValueError`` if ``"comb"`` appears
    as either an old or a new name.
    """

    def __init__(self, domain_map):
        if isinstance(domain_map, str):
            domain_map = {"sync": domain_map}
        for old_name, new_name in domain_map.items():
            if old_name == "comb":
                raise ValueError("Domain '{}' may not be renamed".format(old_name))
            if new_name == "comb":
                raise ValueError("Domain '{}' may not be renamed to '{}'"
                                 .format(old_name, new_name))
        self.domain_map = OrderedDict(domain_map)

    def on_ClockSignal(self, value):
        if value.domain not in self.domain_map:
            return value
        return ClockSignal(self.domain_map[value.domain])

    def on_ResetSignal(self, value):
        if value.domain not in self.domain_map:
            return value
        return ResetSignal(self.domain_map[value.domain])

    def map_domains(self, fragment, new_fragment):
        for domain_name in fragment.iter_domains():
            cd = fragment.domains[domain_name]
            if domain_name in self.domain_map:
                target_name = self.domain_map[domain_name]
                if cd.name == domain_name:
                    # Rename the ClockDomain object itself, in place.
                    cd.rename(target_name)
                else:
                    # NOTE(review): presumably the same ClockDomain object was already
                    # renamed while visiting another fragment; it must carry the new name.
                    assert cd.name == target_name
            new_fragment.add_domains(cd)

    def map_drivers(self, fragment, new_fragment):
        for domain, signals in fragment.drivers.items():
            target_domain = self.domain_map.get(domain, domain)
            for signal in signals:
                new_fragment.add_driver(self.on_value(signal), target_domain)
|
|
|
|
|
|
class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer):
    """Replaces ClockSignal/ResetSignal with the concrete signals of their domains.

    It also appends, for every driven domain that has a reset signal, a switch that
    assigns the reset value to that domain's resettable signals while reset is high.
    """

    def __init__(self, domains=None):
        # Mapping of domain name to ClockDomain used for resolution; replaced by the
        # fragment's own domains in `on_fragment`.
        self.domains = domains

    def _resolve(self, domain, context):
        if domain in self.domains:
            return self.domains[domain]
        raise DomainError("Signal {!r} refers to nonexistent domain '{}'"
                          .format(context, domain))

    def map_drivers(self, fragment, new_fragment):
        for domain, signal in fragment.iter_drivers():
            lowered = self.on_value(signal)
            new_fragment.add_driver(lowered, domain)

    def replace_value_src_loc(self, value, new_value):
        # ClockSignal/ResetSignal lower to the domain's existing clk/rst signals (or a
        # fresh Const); copying src_loc would overwrite the location of a shared signal.
        return not isinstance(value, (ClockSignal, ResetSignal))

    def on_ClockSignal(self, value):
        return self._resolve(value.domain, value).clk

    def on_ResetSignal(self, value):
        cd = self._resolve(value.domain, value)
        if cd.rst is not None:
            return cd.rst
        if not value.allow_reset_less:
            raise DomainError("Signal {!r} refers to reset of reset-less domain '{}'"
                              .format(value, value.domain))
        # The caller allowed a reset-less domain: its reset reads as constant 0.
        return Const(0)

    def _insert_resets(self, fragment):
        for domain_name, signals in fragment.drivers.items():
            if domain_name is None:
                continue
            cd = fragment.domains[domain_name]
            if cd.rst is None:
                continue
            reset_stmts = [signal.eq(Const(signal.reset, signal.width))
                           for signal in signals if not signal.reset_less]
            fragment.add_statements(Switch(cd.rst, {1: reset_stmts}))

    def on_fragment(self, fragment):
        self.domains = fragment.domains
        new_fragment = super().on_fragment(fragment)
        self._insert_resets(new_fragment)
        return new_fragment
|
|
|
|
|
|
class SampleDomainInjector(ValueTransformer, StatementTransformer):
    """Assigns *domain* to every :class:`Sample` that does not name a domain yet."""

    def __init__(self, domain):
        self.domain = domain

    def on_Sample(self, value):
        if value.domain is None:
            return Sample(value.value, value.clocks, self.domain)
        return value

    def __call__(self, stmts):
        return self.on_statement(stmts)
|
|
|
|
|
|
class SampleLowerer(FragmentTransformer, ValueTransformer, StatementTransformer):
    """Lowers :class:`Sample` and :class:`Initial` values into plain signals.

    Each ``Sample`` with ``clocks > 0`` becomes a chain of reset-less registers driven
    in the sample's domain. ``Initial`` becomes a signal driven by an ``$initstate``
    instance added as a subfragment.
    """

    def __init__(self):
        self.initial = None
        self.sample_cache = None
        self.sample_stmts = None

    def _name_reset(self, value):
        """Return a name fragment and the reset value for a register sampling *value*."""
        if isinstance(value, Const):
            return "c${}".format(value.value), value.value
        if isinstance(value, Signal):
            return "s${}".format(value.name), value.reset
        if isinstance(value, ClockSignal):
            return "clk", 0
        if isinstance(value, ResetSignal):
            return "rst", 1
        if isinstance(value, Initial):
            return "init", 0  # Past(Initial()) produces 0, 1, 0, 0, ...
        raise NotImplementedError  # :nocov:

    def on_Sample(self, value):
        if value in self.sample_cache:
            return self.sample_cache[value]

        sampled_value = self.on_value(value.value)
        if value.clocks == 0:
            # Sampling zero clocks ago is just the current value.
            self.sample_cache[value] = sampled_value
            return sampled_value

        assert value.domain is not None
        sampled_name, sampled_reset = self._name_reset(value.value)
        name = "$sample${}${}${}".format(sampled_name, value.domain, value.clocks)
        sample = Signal.like(value.value, name=name, reset_less=True, reset=sampled_reset)
        sample.attrs["nmigen.sample_reg"] = True

        # Register the sample taken one clock fewer ago, building the chain recursively.
        prev_sample = self.on_Sample(Sample(sampled_value, value.clocks - 1, value.domain))
        self.sample_stmts.setdefault(value.domain, []).append(sample.eq(prev_sample))

        self.sample_cache[value] = sample
        return sample

    def on_Initial(self, value):
        # A single shared signal per fragment, created on first use.
        if self.initial is None:
            self.initial = Signal(name="init")
        return self.initial

    def map_statements(self, fragment, new_fragment):
        # Per-fragment state, reset before the fragment's statements are lowered.
        self.initial = None
        self.sample_cache = ValueDict()
        self.sample_stmts = OrderedDict()

        new_fragment.add_statements(map(self.on_statement, fragment.statements))
        for domain, stmts in self.sample_stmts.items():
            new_fragment.add_statements(stmts)
            for stmt in stmts:
                new_fragment.add_driver(stmt.lhs, domain)
        if self.initial is not None:
            new_fragment.add_subfragment(Instance("$initstate", o_Y=self.initial))
|
|
|
|
|
|
class SwitchCleaner(StatementVisitor):
    """Rebuilds statements, dropping switches whose cases all end up empty.

    Case bodies are cleaned first, so a switch that only contains empty switches is
    dropped as well. Leaf statements pass through :meth:`on_ignore`; subclasses can
    drop them by returning ``None``.
    """

    def on_ignore(self, stmt):
        return stmt

    on_Assign = on_ignore
    on_Assert = on_ignore
    on_Assume = on_ignore
    on_Cover = on_ignore

    def on_Switch(self, stmt):
        cleaned_cases = OrderedDict()
        for patterns, body in stmt.cases.items():
            cleaned_cases[patterns] = self.on_statement(body)
        if not any(len(body) for body in cleaned_cases.values()):
            return None
        return Switch(stmt.test, cleaned_cases)

    def on_statements(self, stmts):
        flattened = flatten(self.on_statement(stmt) for stmt in stmts)
        return _StatementList(stmt for stmt in flattened if stmt is not None)
|
|
|
|
|
|
class LHSGroupAnalyzer(StatementVisitor):
    """Partitions the signals assigned by a list of statements into groups.

    Signals that appear together on the left-hand side of one statement end up in the
    same group. The grouping uses a union-find structure: :attr:`signals` maps each
    signal to a group number, and :attr:`unions` maps a merged group number to the
    group it was merged into. Calling the analyzer returns :meth:`groups`.
    """

    def __init__(self):
        self.signals = SignalDict()
        self.unions = OrderedDict()

    def find(self, signal):
        """Return the representative group number of *signal*, registering it if new."""
        if signal not in self.signals:
            # A new signal starts in its own freshly numbered group.
            self.signals[signal] = len(self.signals)
        group = self.signals[signal]
        while group in self.unions:
            group = self.unions[group]
        # Remember the representative so later lookups for this signal are shorter.
        self.signals[signal] = group
        return group

    def unify(self, root, *leaves):
        """Merge the groups of all *leaves* into the group of *root*."""
        root_group = self.find(root)
        for leaf in leaves:
            leaf_group = self.find(leaf)
            if root_group == leaf_group:
                continue
            self.unions[leaf_group] = root_group

    def groups(self):
        """Return an ordered mapping of group number to the SignalSet of its members."""
        groups = OrderedDict()
        for signal in self.signals:
            group = self.find(signal)
            if group not in groups:
                groups[group] = SignalSet()
            groups[group].add(signal)
        return groups

    def _unify_lhs_signals(self, stmt):
        # Compute the left-hand-side signals once rather than twice.
        lhs_signals = stmt._lhs_signals()
        if lhs_signals:
            self.unify(*lhs_signals)

    def on_Assign(self, stmt):
        self._unify_lhs_signals(stmt)

    def on_property(self, stmt):
        # NOTE(review): a property's `_lhs_signals()` are presumably its internal
        # check/enable signals; they are grouped exactly like assignment targets.
        self._unify_lhs_signals(stmt)

    on_Assert = on_property
    on_Assume = on_property
    on_Cover = on_property

    def on_Switch(self, stmt):
        # The switch test is only read, so only the case bodies contribute signals.
        for case_stmts in stmt.cases.values():
            self.on_statements(case_stmts)

    def on_statements(self, stmts):
        for stmt in stmts:
            self.on_statement(stmt)

    def __call__(self, stmts):
        self.on_statements(stmts)
        return self.groups()
|
|
|
|
|
|
class LHSGroupFilter(SwitchCleaner):
    """Keeps only the statements whose left-hand side lies in the group *signals*.

    Every other statement is dropped, and so is any switch left empty as a result.
    """

    def __init__(self, signals):
        self.signals = signals

    def _filter_by_lhs(self, stmt, lhs_signals):
        # The invariant provided by LHSGroupAnalyzer is that all signals that ever appear
        # together on LHS are a part of the same group, so it is sufficient to check any of
        # them. A statement without LHS signals belongs to no group and is dropped; testing
        # for that first also stops a bare `next()` from leaking StopIteration into the
        # generator consumed by `on_statements`, where it would become a RuntimeError.
        if not lhs_signals:
            return None
        any_lhs_signal = next(iter(lhs_signals))
        if any_lhs_signal in self.signals:
            return stmt
        return None

    def on_Assign(self, stmt):
        return self._filter_by_lhs(stmt, stmt.lhs._lhs_signals())

    def on_property(self, stmt):
        return self._filter_by_lhs(stmt, stmt._lhs_signals())

    on_Assert = on_property
    on_Assume = on_property
    on_Cover = on_property
|
|
|
|
|
|
class _ControlInserter(FragmentTransformer):
    """Shared base for passes that add per-domain control logic to a fragment.

    *controls* maps domain names to control values; a bare :class:`Value` is shorthand
    for controlling the ``"sync"`` domain. Subclasses implement :meth:`_insert_control`.
    """

    def __init__(self, controls):
        # Source location of the call that applied this pass; filled in by `__call__`.
        self.src_loc = None
        if isinstance(controls, Value):
            controls = {"sync": controls}
        self.controls = OrderedDict(controls)

    def on_fragment(self, fragment):
        new_fragment = super().on_fragment(fragment)
        for domain, signals in fragment.drivers.items():
            # Skip the `None` domain and any domain without a control value.
            if domain is not None and domain in self.controls:
                self._insert_control(new_fragment, domain, signals)
        return new_fragment

    def _insert_control(self, fragment, domain, signals):
        # Add the control logic for `signals`, which are driven from `domain`.
        raise NotImplementedError  # :nocov:

    def __call__(self, value, *, src_loc_at=0):
        # Remember where the pass was applied; subclasses attach it to the generated Switch.
        self.src_loc = tracer.get_src_loc(src_loc_at=src_loc_at)
        return super().__call__(value, src_loc_at=1 + src_loc_at)
|
|
|
|
|
|
class ResetInserter(_ControlInserter):
    """Adds a synchronous reset controlled by a signal to the given domains."""

    def _insert_control(self, fragment, domain, signals):
        # While the control is 1, assign every resettable signal driven from `domain`
        # its reset value. The switch is appended after the existing statements.
        reset_stmts = [sig.eq(Const(sig.reset, sig.width))
                       for sig in signals if not sig.reset_less]
        control = self.controls[domain]
        fragment.add_statements(Switch(control, {1: reset_stmts}, src_loc=self.src_loc))
|
|
|
|
|
|
class EnableInserter(_ControlInserter):
    """Adds a clock enable controlled by a signal to the given domains."""

    def _insert_control(self, fragment, domain, signals):
        # While the control is 0, assign every signal driven from `domain` to itself,
        # so it keeps its value.
        hold_stmts = [sig.eq(sig) for sig in signals]
        enable = self.controls[domain]
        fragment.add_statements(Switch(enable, {0: hold_stmts}, src_loc=self.src_loc))

    def on_fragment(self, fragment):
        new_fragment = super().on_fragment(fragment)
        # Memory port instances have their own EN port; gate it with the domain's enable.
        if not (isinstance(new_fragment, Instance) and new_fragment.type in ("$memrd", "$memwr")):
            return new_fragment
        clk_port, _clk_dir = new_fragment.named_ports["CLK"]
        if not (isinstance(clk_port, ClockSignal) and clk_port.domain in self.controls):
            return new_fragment
        en_port, en_dir = new_fragment.named_ports["EN"]
        gated_en = Mux(self.controls[clk_port.domain], en_port, Const(0, len(en_port)))
        new_fragment.named_ports["EN"] = gated_en, en_dir
        return new_fragment
|