hdl, back: add and use SignalSet/SignalDict.

This commit is contained in:
whitequark 2018-12-17 17:21:12 +00:00
parent 8c4de99c0d
commit 8d1639a5a8
8 changed files with 87 additions and 54 deletions

View file

@ -20,10 +20,10 @@ class _State:
__slots__ = ("curr", "curr_dirty", "next", "next_dirty") __slots__ = ("curr", "curr_dirty", "next", "next_dirty")
def __init__(self): def __init__(self):
self.curr = ValueDict() self.curr = SignalDict()
self.next = ValueDict() self.next = SignalDict()
self.curr_dirty = ValueSet() self.curr_dirty = SignalSet()
self.next_dirty = ValueSet() self.next_dirty = SignalSet()
def set(self, signal, value): def set(self, signal, value):
assert isinstance(value, int) assert isinstance(value, int)
@ -236,7 +236,7 @@ class _LHSValueCompiler(AbstractValueTransformer):
class _StatementCompiler(AbstractStatementTransformer): class _StatementCompiler(AbstractStatementTransformer):
def __init__(self): def __init__(self):
self.sensitivity = ValueSet() self.sensitivity = SignalSet()
self.rrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="rhs") self.rrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="rhs")
self.lrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="lhs") self.lrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="lhs")
self.lhs_compiler = _LHSValueCompiler(self.lrhs_compiler) self.lhs_compiler = _LHSValueCompiler(self.lrhs_compiler)
@ -284,13 +284,13 @@ class Simulator:
self._fragment = fragment self._fragment = fragment
self._domains = dict() # str/domain -> ClockDomain self._domains = dict() # str/domain -> ClockDomain
self._domain_triggers = ValueDict() # Signal -> str/domain self._domain_triggers = SignalDict() # Signal -> str/domain
self._domain_signals = dict() # str/domain -> {Signal} self._domain_signals = dict() # str/domain -> {Signal}
self._signals = ValueSet() # {Signal} self._signals = SignalSet() # {Signal}
self._comb_signals = ValueSet() # {Signal} self._comb_signals = SignalSet() # {Signal}
self._sync_signals = ValueSet() # {Signal} self._sync_signals = SignalSet() # {Signal}
self._user_signals = ValueSet() # {Signal} self._user_signals = SignalSet() # {Signal}
self._started = False self._started = False
self._timestamp = 0. self._timestamp = 0.
@ -306,12 +306,12 @@ class Simulator:
self._wait_deadline = dict() # process -> float/timestamp self._wait_deadline = dict() # process -> float/timestamp
self._wait_tick = dict() # process -> str/domain self._wait_tick = dict() # process -> str/domain
self._funclets = ValueDict() # Signal -> set(lambda) self._funclets = SignalDict() # Signal -> set(lambda)
self._vcd_file = vcd_file self._vcd_file = vcd_file
self._vcd_writer = None self._vcd_writer = None
self._vcd_signals = ValueDict() # signal -> set(vcd_signal) self._vcd_signals = SignalDict() # signal -> set(vcd_signal)
self._vcd_names = ValueDict() # signal -> str/name self._vcd_names = SignalDict() # signal -> str/name
self._gtkw_file = gtkw_file self._gtkw_file = gtkw_file
self._traces = traces self._traces = traces
@ -381,7 +381,7 @@ class Simulator:
self._domain_triggers[cd.clk] = domain self._domain_triggers[cd.clk] = domain
if cd.rst is not None: if cd.rst is not None:
self._domain_triggers[cd.rst] = domain self._domain_triggers[cd.rst] = domain
self._domain_signals[domain] = ValueSet() self._domain_signals[domain] = SignalSet()
hierarchy = {} hierarchy = {}
def add_fragment(fragment, scope=()): def add_fragment(fragment, scope=()):

View file

@ -213,9 +213,9 @@ class LegalizeValue(Exception):
class _ValueCompilerState: class _ValueCompilerState:
def __init__(self, rtlil): def __init__(self, rtlil):
self.rtlil = rtlil self.rtlil = rtlil
self.wires = ast.ValueDict() self.wires = ast.SignalDict()
self.driven = ast.ValueDict() self.driven = ast.SignalDict()
self.ports = ast.ValueDict() self.ports = ast.SignalDict()
self.expansions = ast.ValueDict() self.expansions = ast.ValueDict()

View file

@ -13,7 +13,7 @@ __all__ = [
"Array", "ArrayProxy", "Array", "ArrayProxy",
"Signal", "ClockSignal", "ResetSignal", "Signal", "ClockSignal", "ResetSignal",
"Statement", "Assign", "Switch", "Delay", "Tick", "Passive", "Statement", "Assign", "Switch", "Delay", "Tick", "Passive",
"ValueKey", "ValueDict", "ValueSet", "ValueKey", "ValueDict", "ValueSet", "SignalKey", "SignalDict", "SignalSet",
] ]
@ -28,14 +28,14 @@ class DUID:
class Value(metaclass=ABCMeta): class Value(metaclass=ABCMeta):
@staticmethod @staticmethod
def wrap(obj): def wrap(obj):
"""Ensures that the passed object is a Migen value. Booleans and integers """Ensures that the passed object is an nMigen value. Booleans and integers
are automatically wrapped into ``Const``.""" are automatically wrapped into ``Const``."""
if isinstance(obj, Value): if isinstance(obj, Value):
return obj return obj
elif isinstance(obj, (bool, int)): elif isinstance(obj, (bool, int)):
return Const(obj) return Const(obj)
else: else:
raise TypeError("Object '{!r}' is not a Migen value".format(obj)) raise TypeError("Object '{!r}' is not an nMigen value".format(obj))
def __init__(self, src_loc_at=0): def __init__(self, src_loc_at=0):
super().__init__() super().__init__()
@ -47,7 +47,7 @@ class Value(metaclass=ABCMeta):
self.src_loc = (tb[0].filename, tb[0].lineno) self.src_loc = (tb[0].filename, tb[0].lineno)
def __bool__(self): def __bool__(self):
raise TypeError("Attempted to convert Migen value to boolean") raise TypeError("Attempted to convert nMigen value to boolean")
def __invert__(self): def __invert__(self):
return Operator("~", [self]) return Operator("~", [self])
@ -801,7 +801,7 @@ class Statement:
if isinstance(obj, Statement): if isinstance(obj, Statement):
return _StatementList([obj]) return _StatementList([obj])
else: else:
raise TypeError("Object '{!r}' is not a Migen statement".format(obj)) raise TypeError("Object '{!r}' is not an nMigen statement".format(obj))
class Assign(Statement): class Assign(Statement):
@ -936,7 +936,8 @@ class _MappedKeyDict(MutableMapping, _MappedKeyCollection):
def __repr__(self): def __repr__(self):
pairs = ["({!r}, {!r})".format(k, v) for k, v in self.items()] pairs = ["({!r}, {!r})".format(k, v) for k, v in self.items()]
return "{}([{}])".format(type(self).__name__, ", ".join(pairs)) return "{}.{}([{}])".format(type(self).__module__, type(self).__name__,
", ".join(pairs))
class _MappedKeySet(MutableSet, _MappedKeyCollection): class _MappedKeySet(MutableSet, _MappedKeyCollection):
@ -967,7 +968,8 @@ class _MappedKeySet(MutableSet, _MappedKeyCollection):
return len(self._storage) return len(self._storage)
def __repr__(self): def __repr__(self):
return "{}({})".format(type(self).__name__, ", ".join(repr(x) for x in self)) return "{}.{}({})".format(type(self).__module__, type(self).__name__,
", ".join(repr(x) for x in self))
class ValueKey: class ValueKey:
@ -1060,3 +1062,34 @@ class ValueDict(_MappedKeyDict):
class ValueSet(_MappedKeySet): class ValueSet(_MappedKeySet):
_map_key = ValueKey _map_key = ValueKey
_unmap_key = lambda self, key: key.value _unmap_key = lambda self, key: key.value
class SignalKey:
def __init__(self, signal):
if not isinstance(signal, Signal):
raise TypeError("Object '{!r}' is not an nMigen signal")
self.signal = signal
def __hash__(self):
return hash(self.signal.duid)
def __eq__(self, other):
return isinstance(other, SignalKey) and self.signal.duid == other.signal.duid
def __lt__(self, other):
if not isinstance(other, SignalKey):
raise TypeError("Object '{!r}' cannot be compared to a SignalKey")
return self.signal.duid < other.signal.duid
def __repr__(self):
return "<{}.SignalKey {!r}>".format(__name__, self.signal)
class SignalDict(_MappedKeyDict):
_map_key = SignalKey
_unmap_key = lambda self, key: key.signal
class SignalSet(_MappedKeySet):
_map_key = SignalKey
_unmap_key = lambda self, key: key.signal

View file

@ -102,7 +102,7 @@ class Module(_ModuleBuilderRoot):
self._ctrl_context = None self._ctrl_context = None
self._ctrl_stack = [] self._ctrl_stack = []
self._driving = ValueDict() self._driving = SignalDict()
self._submodules = [] self._submodules = []
self._domains = [] self._domains = []

View file

@ -15,7 +15,7 @@ class DriverConflict(UserWarning):
class Fragment: class Fragment:
def __init__(self): def __init__(self):
self.ports = ValueDict() self.ports = SignalDict()
self.drivers = OrderedDict() self.drivers = OrderedDict()
self.statements = [] self.statements = []
self.domains = OrderedDict() self.domains = OrderedDict()
@ -31,7 +31,7 @@ class Fragment:
def add_driver(self, signal, domain=None): def add_driver(self, signal, domain=None):
if domain not in self.drivers: if domain not in self.drivers:
self.drivers[domain] = ValueSet() self.drivers[domain] = SignalSet()
self.drivers[domain].add(signal) self.drivers[domain].add(signal)
def iter_drivers(self): def iter_drivers(self):
@ -51,7 +51,7 @@ class Fragment:
yield domain, signal yield domain, signal
def iter_signals(self): def iter_signals(self):
signals = ValueSet() signals = SignalSet()
signals |= self.ports.keys() signals |= self.ports.keys()
for domain, domain_signals in self.drivers.items(): for domain, domain_signals in self.drivers.items():
if domain is not None: if domain is not None:
@ -81,7 +81,7 @@ class Fragment:
def _resolve_driver_conflicts(self, hierarchy=("top",), mode="warn"): def _resolve_driver_conflicts(self, hierarchy=("top",), mode="warn"):
assert mode in ("silent", "warn", "error") assert mode in ("silent", "warn", "error")
driver_subfrags = ValueDict() driver_subfrags = SignalDict()
# For each signal driven by this fragment and/or its subfragments, determine which # For each signal driven by this fragment and/or its subfragments, determine which
# subfragments also drive it. # subfragments also drive it.
@ -147,7 +147,7 @@ class Fragment:
return self._resolve_driver_conflicts(hierarchy, mode) return self._resolve_driver_conflicts(hierarchy, mode)
# Nothing was flattened, we're done! # Nothing was flattened, we're done!
return ValueSet(driver_subfrags.keys()) return SignalSet(driver_subfrags.keys())
def _propagate_domains_up(self, hierarchy=("top",)): def _propagate_domains_up(self, hierarchy=("top",)):
from .xfrm import DomainRenamer from .xfrm import DomainRenamer
@ -229,8 +229,8 @@ class Fragment:
def _propagate_ports(self, ports): def _propagate_ports(self, ports):
# Collect all signals we're driving (on LHS of statements), and signals we're using # Collect all signals we're driving (on LHS of statements), and signals we're using
# (on RHS of statements, or in clock domains). # (on RHS of statements, or in clock domains).
self_driven = union(s._lhs_signals() for s in self.statements) or ValueSet() self_driven = union(s._lhs_signals() for s in self.statements) or SignalSet()
self_used = union(s._rhs_signals() for s in self.statements) or ValueSet() self_used = union(s._rhs_signals() for s in self.statements) or SignalSet()
for domain, _ in self.iter_sync(): for domain, _ in self.iter_sync():
cd = self.domains[domain] cd = self.domains[domain]
self_used.add(cd.clk) self_used.add(cd.clk)

View file

@ -369,7 +369,7 @@ class DSLTestCase(FHDLTestCase):
) )
""") """)
self.assertEqual(f1.drivers, { self.assertEqual(f1.drivers, {
None: ValueSet((self.c1,)) None: SignalSet((self.c1,))
}) })
self.assertEqual(len(f1.subfragments), 1) self.assertEqual(len(f1.subfragments), 1)
(f2, f2_name), = f1.subfragments (f2, f2_name), = f1.subfragments
@ -381,7 +381,7 @@ class DSLTestCase(FHDLTestCase):
) )
""") """)
self.assertEqual(f2.drivers, { self.assertEqual(f2.drivers, {
None: ValueSet((self.c2,)), None: SignalSet((self.c2,)),
"sync": ValueSet((self.c3,)) "sync": SignalSet((self.c3,))
}) })
self.assertEqual(len(f2.subfragments), 0) self.assertEqual(len(f2.subfragments), 0)

View file

@ -25,12 +25,12 @@ class FragmentPortsTestCase(FHDLTestCase):
self.assertEqual(list(f.iter_ports()), []) self.assertEqual(list(f.iter_ports()), [])
f._propagate_ports(ports=()) f._propagate_ports(ports=())
self.assertEqual(f.ports, ValueDict([])) self.assertEqual(f.ports, SignalDict([]))
def test_iter_signals(self): def test_iter_signals(self):
f = Fragment() f = Fragment()
f.add_ports(self.s1, self.s2, kind="io") f.add_ports(self.s1, self.s2, kind="io")
self.assertEqual(ValueSet((self.s1, self.s2)), f.iter_signals()) self.assertEqual(SignalSet((self.s1, self.s2)), f.iter_signals())
def test_self_contained(self): def test_self_contained(self):
f = Fragment() f = Fragment()
@ -40,7 +40,7 @@ class FragmentPortsTestCase(FHDLTestCase):
) )
f._propagate_ports(ports=()) f._propagate_ports(ports=())
self.assertEqual(f.ports, ValueDict([])) self.assertEqual(f.ports, SignalDict([]))
def test_infer_input(self): def test_infer_input(self):
f = Fragment() f = Fragment()
@ -49,7 +49,7 @@ class FragmentPortsTestCase(FHDLTestCase):
) )
f._propagate_ports(ports=()) f._propagate_ports(ports=())
self.assertEqual(f.ports, ValueDict([ self.assertEqual(f.ports, SignalDict([
(self.s1, "i") (self.s1, "i")
])) ]))
@ -60,7 +60,7 @@ class FragmentPortsTestCase(FHDLTestCase):
) )
f._propagate_ports(ports=(self.c1,)) f._propagate_ports(ports=(self.c1,))
self.assertEqual(f.ports, ValueDict([ self.assertEqual(f.ports, SignalDict([
(self.s1, "i"), (self.s1, "i"),
(self.c1, "o") (self.c1, "o")
])) ]))
@ -76,8 +76,8 @@ class FragmentPortsTestCase(FHDLTestCase):
) )
f1.add_subfragment(f2) f1.add_subfragment(f2)
f1._propagate_ports(ports=()) f1._propagate_ports(ports=())
self.assertEqual(f1.ports, ValueDict()) self.assertEqual(f1.ports, SignalDict())
self.assertEqual(f2.ports, ValueDict([ self.assertEqual(f2.ports, SignalDict([
(self.s1, "o"), (self.s1, "o"),
])) ]))
@ -89,10 +89,10 @@ class FragmentPortsTestCase(FHDLTestCase):
) )
f1.add_subfragment(f2) f1.add_subfragment(f2)
f1._propagate_ports(ports=()) f1._propagate_ports(ports=())
self.assertEqual(f1.ports, ValueDict([ self.assertEqual(f1.ports, SignalDict([
(self.s1, "i"), (self.s1, "i"),
])) ]))
self.assertEqual(f2.ports, ValueDict([ self.assertEqual(f2.ports, SignalDict([
(self.s1, "i"), (self.s1, "i"),
])) ]))
@ -108,10 +108,10 @@ class FragmentPortsTestCase(FHDLTestCase):
f1.add_subfragment(f2) f1.add_subfragment(f2)
f1._propagate_ports(ports=(self.c2,)) f1._propagate_ports(ports=(self.c2,))
self.assertEqual(f1.ports, ValueDict([ self.assertEqual(f1.ports, SignalDict([
(self.c2, "o"), (self.c2, "o"),
])) ]))
self.assertEqual(f2.ports, ValueDict([ self.assertEqual(f2.ports, SignalDict([
(self.c2, "o"), (self.c2, "o"),
])) ]))
@ -125,7 +125,7 @@ class FragmentPortsTestCase(FHDLTestCase):
f.add_driver(self.c1, "sync") f.add_driver(self.c1, "sync")
f._propagate_ports(ports=()) f._propagate_ports(ports=())
self.assertEqual(f.ports, ValueDict([ self.assertEqual(f.ports, SignalDict([
(self.s1, "i"), (self.s1, "i"),
(sync.clk, "i"), (sync.clk, "i"),
(sync.rst, "i"), (sync.rst, "i"),
@ -141,7 +141,7 @@ class FragmentPortsTestCase(FHDLTestCase):
f.add_driver(self.c1, "sync") f.add_driver(self.c1, "sync")
f._propagate_ports(ports=()) f._propagate_ports(ports=())
self.assertEqual(f.ports, ValueDict([ self.assertEqual(f.ports, SignalDict([
(self.s1, "i"), (self.s1, "i"),
(sync.clk, "i"), (sync.clk, "i"),
])) ]))
@ -157,9 +157,9 @@ class FragmentDomainsTestCase(FHDLTestCase):
f = Fragment() f = Fragment()
f.add_domains(cd1, cd2) f.add_domains(cd1, cd2)
f.add_driver(s1, "cd1") f.add_driver(s1, "cd1")
self.assertEqual(ValueSet((cd1.clk, cd1.rst, s1)), f.iter_signals()) self.assertEqual(SignalSet((cd1.clk, cd1.rst, s1)), f.iter_signals())
f.add_driver(s2, "cd2") f.add_driver(s2, "cd2")
self.assertEqual(ValueSet((cd1.clk, cd1.rst, cd2.clk, s1, s2)), f.iter_signals()) self.assertEqual(SignalSet((cd1.clk, cd1.rst, cd2.clk, s1, s2)), f.iter_signals())
def test_propagate_up(self): def test_propagate_up(self):
cd = ClockDomain() cd = ClockDomain()
@ -315,8 +315,8 @@ class FragmentDriverConflictTestCase(FHDLTestCase):
) )
""") """)
self.assertEqual(self.f1.drivers, { self.assertEqual(self.f1.drivers, {
None: ValueSet((self.s1,)), None: SignalSet((self.s1,)),
"sync": ValueSet((self.c1, self.c2)), "sync": SignalSet((self.c1, self.c2)),
}) })
def test_conflict_self_sub_error(self): def test_conflict_self_sub_error(self):

View file

@ -38,8 +38,8 @@ class DomainRenamerTestCase(FHDLTestCase):
) )
""") """)
self.assertEqual(f.drivers, { self.assertEqual(f.drivers, {
None: ValueSet((self.s1, self.s2)), None: SignalSet((self.s1, self.s2)),
"pix": ValueSet((self.s3,)), "pix": SignalSet((self.s3,)),
}) })
def test_rename_multi(self): def test_rename_multi(self):