back.pysim: implement ArrayProxy.

This commit is contained in:
whitequark 2018-12-15 19:37:36 +00:00
parent 80c5343600
commit 54fb999c99
4 changed files with 39 additions and 2 deletions

View file

@ -148,6 +148,12 @@ class _RHSValueCompiler(ValueTransformer):
return normalize(result, shape) return normalize(result, shape)
return eval return eval
def on_ArrayProxy(self, value):
shape = value.shape()
elems = list(map(self, value.elems))
index = self(value.index)
return lambda state: normalize(elems[index(state)](state), shape)
class _StatementCompiler(StatementTransformer): class _StatementCompiler(StatementTransformer):
def __init__(self): def __init__(self):

View file

@ -776,10 +776,12 @@ class ArrayProxy(Value):
return bits, sign return bits, sign
def _lhs_signals(self): def _lhs_signals(self):
return union((elem._lhs_signals() for elem in self._iter_as_values()), start=ValueSet()) signals = union((elem._lhs_signals() for elem in self._iter_as_values()), start=ValueSet())
return signals
def _rhs_signals(self): def _rhs_signals(self):
return union((elem._rhs_signals() for elem in self._iter_as_values()), start=ValueSet()) signals = union((elem._rhs_signals() for elem in self._iter_as_values()), start=ValueSet())
return self.index._rhs_signals() | signals
def __repr__(self): def __repr__(self):
return "(proxy (array [{}]) {!r})".format(", ".join(map(repr, self.elems)), self.index) return "(proxy (array [{}]) {!r})".format(", ".join(map(repr, self.elems)), self.index)

View file

@ -40,6 +40,10 @@ class ValueTransformer:
def on_Repl(self, value): def on_Repl(self, value):
return Repl(self.on_value(value.value), value.count) return Repl(self.on_value(value.value), value.count)
def on_ArrayProxy(self, value):
return ArrayProxy([self.on_value(elem) for elem in value._iter_as_values()],
self.on_value(value.index))
def on_unknown_value(self, value): def on_unknown_value(self, value):
raise TypeError("Cannot transform value '{!r}'".format(value)) # :nocov: raise TypeError("Cannot transform value '{!r}'".format(value)) # :nocov:
@ -62,6 +66,8 @@ class ValueTransformer:
new_value = self.on_Cat(value) new_value = self.on_Cat(value)
elif isinstance(value, Repl): elif isinstance(value, Repl):
new_value = self.on_Repl(value) new_value = self.on_Repl(value)
elif isinstance(value, ArrayProxy):
new_value = self.on_ArrayProxy(value)
else: else:
new_value = self.on_unknown_value(value) new_value = self.on_unknown_value(value)
if isinstance(new_value, Value): if isinstance(new_value, Value):

View file

@ -136,3 +136,26 @@ class SimulatorUnitTestCase(FHDLTestCase):
def test_repl(self): def test_repl(self):
stmt = lambda a: Repl(a, 3) stmt = lambda a: Repl(a, 3)
self.assertOperator(stmt, [C(0b10, 2)], C(0b101010, 6)) self.assertOperator(stmt, [C(0b10, 2)], C(0b101010, 6))
def test_array(self):
array = Array([1, 4, 10])
stmt = lambda a: array[a]
self.assertOperator(stmt, [C(0)], C(1))
self.assertOperator(stmt, [C(1)], C(4))
self.assertOperator(stmt, [C(2)], C(10))
def test_array_index(self):
array = Array(Array(x * y for y in range(10)) for x in range(10))
stmt = lambda a, b: array[a][b]
for x in range(10):
for y in range(10):
self.assertOperator(stmt, [C(x), C(y)], C(x * y))
def test_array_attr(self):
from collections import namedtuple
pair = namedtuple("pair", ("p", "n"))
array = Array(pair(x, -x) for x in range(10))
stmt = lambda a: array[a].p + array[a].n
for i in range(10):
self.assertOperator(stmt, [C(i)], C(0))