From 6f5d009fad7ba98297da2ad1d55cc2f85a08f31c Mon Sep 17 00:00:00 2001 From: Wanda Date: Thu, 11 Apr 2024 11:07:58 +0200 Subject: [PATCH] sim: fix LRHS evaluation. Fixes #1269. --- amaranth/sim/_pyrtl.py | 12 ++++++++---- tests/test_sim.py | 10 ++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/amaranth/sim/_pyrtl.py b/amaranth/sim/_pyrtl.py index f6c4943..5505cfb 100644 --- a/amaranth/sim/_pyrtl.py +++ b/amaranth/sim/_pyrtl.py @@ -153,12 +153,16 @@ class _ValueCompiler(ValueVisitor, _Compiler): class _RHSValueCompiler(_ValueCompiler): - def __init__(self, state, emitter, *, mode, inputs=None): + def __init__(self, state, emitter, *, mode, inputs=None, rrhs=None): super().__init__(state, emitter) assert mode in ("curr", "next") self.mode = mode # If not None, `inputs` gets populated with RHS signals. self.inputs = inputs + # When this compiler is used to grab the "next" value from within _LHSValueCompiler, + # we still need to use "curr" mode for reading part offsets etc. Allow setting a separate + # _RhsValueCompiler for these contexts. + self.rrhs = rrhs or self def sign(self, value): value_mask = (1 << len(value)) - 1 @@ -251,7 +255,7 @@ class _RHSValueCompiler(_ValueCompiler): def on_Part(self, value): offset_mask = (1 << len(value.offset)) - 1 - offset = f"({value.stride} * ({offset_mask:#x} & {self(value.offset)}))" + offset = f"({value.stride} * ({offset_mask:#x} & {self.rrhs(value.offset)}))" return f"({(1 << value.width) - 1} & " \ f"{self(value.value)} >> {offset})" @@ -267,7 +271,7 @@ class _RHSValueCompiler(_ValueCompiler): return f"0" def on_SwitchValue(self, value): - gen_test = self.emitter.def_var("test", f"{(1 << len(value.test)) - 1:#x} & {self(value.test)}") + gen_test = self.emitter.def_var("test", f"{(1 << len(value.test)) - 1:#x} & {self.rrhs(value.test)}") gen_value = self.emitter.def_var("rhs_switch", "0") def case_handler(patterns, elem): self.emitter.append(f"{gen_value} = {self.sign(elem)}") @@ -290,7 +294,7 @@ class _LHSValueCompiler(_ValueCompiler): self.rrhs = rhs # `lrhs` is used to translate the read part of a read-modify-write cycle during partial # update of an lvalue. - self.lrhs = _RHSValueCompiler(state, emitter, mode="next", inputs=None) + self.lrhs = _RHSValueCompiler(state, emitter, mode="next", inputs=None, rrhs=rhs) # If not None, `outputs` gets populated with signals on LHS. self.outputs = outputs diff --git a/tests/test_sim.py b/tests/test_sim.py index fa4f881..6b9affd 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -366,6 +366,16 @@ class SimulatorUnitTestCase(FHDLTestCase): self.assertStatement(stmt, [C(1), C(0b010)], C(0b001110101, 9)) self.assertStatement(stmt, [C(2), C(0b100)], C(0b001001001, 9)) + def test_array_lhs_heterogenous_slice(self): + l = Signal(1, init=1) + m = Signal(3, init=4) + n = Signal(5, init=7) + array = Array([l, m, n]) + stmt = lambda y, a, b: [array[a].as_value()[2:].eq(b), y.eq(Cat(*array))] + self.assertStatement(stmt, [C(0), C(0b000)], C(0b001111001, 9)) + self.assertStatement(stmt, [C(1), C(0b010)], C(0b001110001, 9)) + self.assertStatement(stmt, [C(2), C(0b100)], C(0b100111001, 9)) + def test_array_lhs_oob(self): l = Signal(3) m = Signal(3)