hdl.ast: implement Array and ArrayProxy.

This commit is contained in:
whitequark 2018-12-15 17:16:22 +00:00
parent 1580b6e542
commit 80c5343600
5 changed files with 199 additions and 8 deletions

View file

@ -1,4 +1,4 @@
from .hdl.ast import Value, Const, Mux, Cat, Repl, Signal, ClockSignal, ResetSignal from .hdl.ast import Value, Const, Mux, Cat, Repl, Array, Signal, ClockSignal, ResetSignal
from .hdl.dsl import Module from .hdl.dsl import Module
from .hdl.cd import ClockDomain from .hdl.cd import ClockDomain
from .hdl.ir import Fragment from .hdl.ir import Fragment

View file

@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod
import builtins import builtins
import traceback import traceback
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Iterable, MutableMapping, MutableSet from collections.abc import Iterable, MutableMapping, MutableSet, MutableSequence
from .. import tracer from .. import tracer
from ..tools import * from ..tools import *
@ -10,6 +10,7 @@ from ..tools import *
__all__ = [ __all__ = [
"Value", "Const", "C", "Operator", "Mux", "Part", "Slice", "Cat", "Repl", "Value", "Const", "C", "Operator", "Mux", "Part", "Slice", "Cat", "Repl",
"Array", "ArrayProxy",
"Signal", "ClockSignal", "ResetSignal", "Signal", "ClockSignal", "ResetSignal",
"Statement", "Assign", "Switch", "Delay", "Tick", "Passive", "Statement", "Assign", "Switch", "Delay", "Tick", "Passive",
"ValueKey", "ValueDict", "ValueSet", "ValueKey", "ValueDict", "ValueSet",
@ -39,8 +40,7 @@ class Value(metaclass=ABCMeta):
def __init__(self, src_loc_at=0): def __init__(self, src_loc_at=0):
super().__init__() super().__init__()
src_loc_at += 3 tb = traceback.extract_stack(limit=3 + src_loc_at)
tb = traceback.extract_stack(limit=src_loc_at)
if len(tb) < src_loc_at: if len(tb) < src_loc_at:
self.src_loc = None self.src_loc = None
else: else:
@ -664,6 +664,127 @@ class ResetSignal(Value):
return "(rst {})".format(self.domain) return "(rst {})".format(self.domain)
class Array(MutableSequence):
"""Addressable multiplexer.
An array is similar to a ``list`` that can also be indexed by ``Value``s; indexing by an integer or a slice works the same as for Python lists, but indexing by a ``Value`` results
in a proxy.
The array proxy can be used as an ordinary ``Value``, i.e. participate in calculations and
assignments, provided that all elements of the array are values. The array proxy also supports
attribute access and further indexing, each returning another array proxy; this means that
the results of indexing into arrays, arrays of records, and arrays of arrays can all
be used as first-class values.
It is an error to change an array or any of its elements after an array proxy was created.
Changing the array directly will raise an exception. However, it is not possible to detect
the elements being modified; if an element's attribute or element is modified after the proxy
for it has been created, the proxy will refer to stale data.
Examples
--------
Simple array::
gpios = Array(Signal() for _ in range(10))
with m.If(bus.we):
m.d.sync += gpios[bus.adr].eq(bus.dat_w)
with m.Else():
m.d.sync += bus.dat_r.eq(gpios[bus.adr])
Multidimensional array::
mult = Array(Array(x * y for y in range(10)) for x in range(10))
a = Signal(max=10)
b = Signal(max=10)
r = Signal(8)
m.d.comb += r.eq(mult[a][b])
Array of records::
layout = [
("re", 1),
("dat_r", 16),
]
buses = Array(Record(layout) for busno in range(4))
master = Record(layout)
m.d.comb += [
buses[sel].re.eq(master.re),
master.dat_r.eq(buses[sel].dat_r),
]
"""
def __init__(self, iterable):
self._inner = list(iterable)
self._proxy_at = None
self._mutable = True
def __getitem__(self, index):
if isinstance(index, Value):
if self._mutable:
tb = traceback.extract_stack(limit=2)
self._proxy_at = (tb[0].filename, tb[0].lineno)
self._mutable = False
return ArrayProxy(self, index)
else:
return self._inner[index]
def __len__(self):
return len(self._inner)
def _check_mutability(self):
if not self._mutable:
raise ValueError("Array can no longer be mutated after it was indexed with a value "
"at {}:{}".format(*self._proxy_at))
def __setitem__(self, index, value):
self._check_mutability()
self._inner[index] = value
def __delitem__(self, index):
self._check_mutability()
del self._inner[index]
def insert(self, index, value):
self._check_mutability()
self._inner.insert(index, value)
def __repr__(self):
return "(array{} [{}])".format(" mutable" if self._mutable else "",
", ".join(map(repr, self._inner)))
class ArrayProxy(Value):
def __init__(self, elems, index):
super().__init__(src_loc_at=1)
self.elems = elems
self.index = Value.wrap(index)
def __getattr__(self, attr):
return ArrayProxy([getattr(elem, attr) for elem in self.elems], self.index)
def __getitem__(self, index):
return ArrayProxy([ elem[index] for elem in self.elems], self.index)
def _iter_as_values(self):
return (Value.wrap(elem) for elem in self.elems)
def shape(self):
bits, sign = 0, False
for elem_bits, elem_sign in (elem.shape() for elem in self._iter_as_values()):
bits = max(bits, elem_bits + elem_sign)
sign = max(sign, elem_sign)
return bits, sign
def _lhs_signals(self):
return union((elem._lhs_signals() for elem in self._iter_as_values()), start=ValueSet())
def _rhs_signals(self):
return union((elem._rhs_signals() for elem in self._iter_as_values()), start=ValueSet())
def __repr__(self):
return "(proxy (array [{}]) {!r})".format(", ".join(map(repr, self.elems)), self.index)
class _StatementList(list): class _StatementList(list):
def __repr__(self): def __repr__(self):
return "({})".format(" ".join(map(repr, self))) return "({})".format(" ".join(map(repr, self)))
@ -713,11 +834,13 @@ class Switch(Statement):
self.cases[key] = Statement.wrap(stmts) self.cases[key] = Statement.wrap(stmts)
def _lhs_signals(self): def _lhs_signals(self):
signals = union(s._lhs_signals() for ss in self.cases.values() for s in ss) or ValueSet() signals = union((s._lhs_signals() for ss in self.cases.values() for s in ss),
start=ValueSet())
return signals return signals
def _rhs_signals(self): def _rhs_signals(self):
signals = union(s._rhs_signals() for ss in self.cases.values() for s in ss) or ValueSet() signals = union((s._rhs_signals() for ss in self.cases.values() for s in ss),
start=ValueSet())
return self.test._rhs_signals() | signals return self.test._rhs_signals() | signals
def __repr__(self): def __repr__(self):

View file

@ -334,6 +334,66 @@ class ReplTestCase(FHDLTestCase):
self.assertEqual(repr(s), "(repl (const 4'd10) 3)") self.assertEqual(repr(s), "(repl (const 4'd10) 3)")
class ArrayTestCase(FHDLTestCase):
def test_acts_like_array(self):
a = Array([1,2,3])
self.assertSequenceEqual(a, [1,2,3])
self.assertEqual(a[1], 2)
a[1] = 4
self.assertSequenceEqual(a, [1,4,3])
del a[1]
self.assertSequenceEqual(a, [1,3])
a.insert(1, 2)
self.assertSequenceEqual(a, [1,2,3])
def test_becomes_immutable(self):
a = Array([1,2,3])
s1 = Signal(max=len(a))
s2 = Signal(max=len(a))
v1 = a[s1]
v2 = a[s2]
with self.assertRaisesRegex(ValueError,
regex=r"^Array can no longer be mutated after it was indexed with a value at "):
a[1] = 2
with self.assertRaisesRegex(ValueError,
regex=r"^Array can no longer be mutated after it was indexed with a value at "):
del a[1]
with self.assertRaisesRegex(ValueError,
regex=r"^Array can no longer be mutated after it was indexed with a value at "):
a.insert(1, 2)
def test_repr(self):
a = Array([1,2,3])
self.assertEqual(repr(a), "(array mutable [1, 2, 3])")
s = Signal(max=len(a))
v = a[s]
self.assertEqual(repr(a), "(array [1, 2, 3])")
class ArrayProxyTestCase(FHDLTestCase):
def test_index_shape(self):
m = Array(Array(x * y for y in range(1, 4)) for x in range(1, 4))
a = Signal(max=3)
b = Signal(max=3)
v = m[a][b]
self.assertEqual(v.shape(), (4, False))
def test_attr_shape(self):
from collections import namedtuple
pair = namedtuple("pair", ("p", "n"))
a = Array(pair(i, -i) for i in range(10))
s = Signal(max=len(a))
v = a[s]
self.assertEqual(v.p.shape(), (4, False))
self.assertEqual(v.n.shape(), (6, True))
def test_repr(self):
a = Array([1, 2, 3])
s = Signal(max=3)
v = a[s]
self.assertEqual(repr(v), "(proxy (array [1, 2, 3]) (sig s))")
class SignalTestCase(FHDLTestCase): class SignalTestCase(FHDLTestCase):
def test_shape(self): def test_shape(self):
s1 = Signal() s1 = Signal()

View file

@ -25,6 +25,14 @@ class FHDLTestCase(unittest.TestCase):
# WTF? unittest.assertRaises is completely broken. # WTF? unittest.assertRaises is completely broken.
self.assertEqual(str(cm.exception), msg) self.assertEqual(str(cm.exception), msg)
@contextmanager
def assertRaisesRegex(self, exception, regex=None):
with super().assertRaises(exception) as cm:
yield
if regex is not None:
# unittest.assertRaisesRegex also seems broken...
self.assertRegex(str(cm.exception), regex)
@contextmanager @contextmanager
def assertWarns(self, category, msg=None): def assertWarns(self, category, msg=None):
with warnings.catch_warnings(record=True) as warns: with warnings.catch_warnings(record=True) as warns:

View file

@ -14,8 +14,8 @@ def flatten(i):
yield e yield e
def union(i): def union(i, start=None):
r = None r = start
for e in i: for e in i:
if r is None: if r is None:
r = e r = e