back.pysim: handle out of bounds ArrayProxy indexes.

This commit is contained in:
whitequark 2018-12-21 12:32:08 +00:00
parent 7ae7683fed
commit 48d13e47ec
2 changed files with 25 additions and 2 deletions

View file

@ -193,7 +193,12 @@ class _RHSValueCompiler(AbstractValueTransformer):
shape = value.shape() shape = value.shape()
elems = list(map(self, value.elems)) elems = list(map(self, value.elems))
index = self(value.index) index = self(value.index)
return lambda state: normalize(elems[index(state)](state), shape) def eval(state):
index_value = index(state)
if index_value >= len(elems):
index_value = len(elems) - 1
return normalize(elems[index_value](state), shape)
return eval
class _LHSValueCompiler(AbstractValueTransformer): class _LHSValueCompiler(AbstractValueTransformer):
@ -263,7 +268,10 @@ class _LHSValueCompiler(AbstractValueTransformer):
elems = list(map(self, value.elems)) elems = list(map(self, value.elems))
index = self.rhs_compiler(value.index) index = self.rhs_compiler(value.index)
def eval(state, rhs): def eval(state, rhs):
elems[index(state)](state, rhs) index_value = index(state)
if index_value >= len(elems):
index_value = len(elems) - 1
elems[index_value](state, rhs)
return eval return eval

View file

@ -184,6 +184,12 @@ class SimulatorUnitTestCase(FHDLTestCase):
self.assertStatement(stmt, [C(1)], C(4)) self.assertStatement(stmt, [C(1)], C(4))
self.assertStatement(stmt, [C(2)], C(10)) self.assertStatement(stmt, [C(2)], C(10))
def test_array_oob(self):
array = Array([1, 4, 10])
stmt = lambda y, a: y.eq(array[a])
self.assertStatement(stmt, [C(3)], C(10))
self.assertStatement(stmt, [C(4)], C(10))
def test_array_lhs(self): def test_array_lhs(self):
l = Signal(3, reset=1) l = Signal(3, reset=1)
m = Signal(3, reset=4) m = Signal(3, reset=4)
@ -194,6 +200,15 @@ class SimulatorUnitTestCase(FHDLTestCase):
self.assertStatement(stmt, [C(1), C(0b010)], C(0b111010001)) self.assertStatement(stmt, [C(1), C(0b010)], C(0b111010001))
self.assertStatement(stmt, [C(2), C(0b100)], C(0b100100001)) self.assertStatement(stmt, [C(2), C(0b100)], C(0b100100001))
def test_array_lhs_oob(self):
l = Signal(3)
m = Signal(3)
n = Signal(3)
array = Array([l, m, n])
stmt = lambda y, a, b: [array[a].eq(b), y.eq(Cat(*array))]
self.assertStatement(stmt, [C(3), C(0b001)], C(0b001000000))
self.assertStatement(stmt, [C(4), C(0b010)], C(0b010000000))
def test_array_index(self): def test_array_index(self):
array = Array(Array(x * y for y in range(10)) for x in range(10)) array = Array(Array(x * y for y in range(10)) for x in range(10))
stmt = lambda y, a, b: y.eq(array[a][b]) stmt = lambda y, a, b: y.eq(array[a][b])