diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index 9358072..d60e6af 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -18,7 +18,7 @@ from .._unused import * __all__ = [ "SyntaxError", "SyntaxWarning", "Shape", "signed", "unsigned", "ShapeCastable", "ShapeLike", - "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat", "SwitchValue", + "Value", "Const", "C", "AnyValue", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat", "SwitchValue", "Array", "ArrayProxy", "Signal", "ClockSignal", "ResetSignal", "ValueCastable", "ValueLike", diff --git a/amaranth/sim/_pycoro.py b/amaranth/sim/_pycoro.py index 75e9d19..0c705a9 100644 --- a/amaranth/sim/_pycoro.py +++ b/amaranth/sim/_pycoro.py @@ -5,7 +5,7 @@ from ..hdl._ast import Statement, Assign, SignalSet, ValueCastable from ..hdl._mem import MemorySimRead, MemorySimWrite from .core import Tick, Settle, Delay, Passive, Active from ._base import BaseProcess, BaseMemoryState -from ._pyrtl import _ValueCompiler, _RHSValueCompiler, _StatementCompiler +from ._pyeval import eval_value, eval_assign __all__ = ["PyCoroProcess"] @@ -28,11 +28,6 @@ class PyCoroProcess(BaseProcess): self.passive = False self.coroutine = self.constructor() - self.exec_locals = { - "slots": self.state.slots, - "result": None, - **_ValueCompiler.helpers - } self.waits_on = SignalSet() def src_loc(self): @@ -87,14 +82,11 @@ class PyCoroProcess(BaseProcess): if isinstance(command, ValueCastable): command = Value.cast(command) if isinstance(command, Value): - exec(_RHSValueCompiler.compile(self.state, command, mode="curr"), - self.exec_locals) - response = Const(self.exec_locals["result"], command.shape()).value + response = eval_value(self.state, command) - elif isinstance(command, Statement): - exec(_StatementCompiler.compile(self.state, command), - self.exec_locals) - if isinstance(command, Assign) and self.testbench: + elif isinstance(command, Assign): + eval_assign(self.state, command.lhs, eval_value(self.state, command.rhs)) + if self.testbench: return True # assignment; run a delta cycle elif type(command) is Tick: @@ -132,21 +124,15 @@ class PyCoroProcess(BaseProcess): self.passive = False elif type(command) is MemorySimRead: - exec(_RHSValueCompiler.compile(self.state, command._addr, mode="curr"), - self.exec_locals) - addr = Const(self.exec_locals["result"], command._addr.shape()).value + addr = eval_value(self.state, command._addr) index = self.state.get_memory(command._memory) state = self.state.slots[index] assert isinstance(state, BaseMemoryState) response = state.read(addr) elif type(command) is MemorySimWrite: - exec(_RHSValueCompiler.compile(self.state, command._addr, mode="curr"), - self.exec_locals) - addr = Const(self.exec_locals["result"], command._addr.shape()).value - exec(_RHSValueCompiler.compile(self.state, command._data, mode="curr"), - self.exec_locals) - data = Const(self.exec_locals["result"], command._data.shape()).value + addr = eval_value(self.state, command._addr) + data = eval_value(self.state, command._data) index = self.state.get_memory(command._memory) state = self.state.slots[index] assert isinstance(state, BaseMemoryState) diff --git a/amaranth/sim/_pyeval.py b/amaranth/sim/_pyeval.py new file mode 100644 index 0000000..510b0e9 --- /dev/null +++ b/amaranth/sim/_pyeval.py @@ -0,0 +1,184 @@ +from amaranth.hdl._ast import * + + +def _eval_matches(test, patterns): + if patterns is None: + return True + for pattern in patterns: + if isinstance(pattern, str): + mask = int("".join("0" if b == "-" else "1" for b in pattern), 2) + value = int("".join("0" if b == "-" else b for b in pattern), 2) + if value == (mask & test): + return True + else: + if pattern == test: + return True + return False + + +def eval_value(sim, value): + if isinstance(value, Const): + return value.value + elif isinstance(value, Operator): + if len(value.operands) == 1: + op_a = eval_value(sim, value.operands[0]) + if value.operator in ("u", "s"): + width = value.shape().width + res = op_a + res &= (1 << width) - 1 + if value.operator == "s" and res & (1 << (width - 1)): + res |= -1 << (width - 1) + return res + elif value.operator == "-": + return -op_a + elif value.operator == "~": + shape = value.shape() + if shape.signed: + return ~op_a + else: + return ~op_a & ((1 << shape.width) - 1) + elif value.operator in ("b", "r|"): + return int(op_a != 0) + elif value.operator == "r&": + width = value.operands[0].shape().width + mask = (1 << width) - 1 + return int((op_a & mask) == mask) + elif value.operator == "r^": + width = value.operands[0].shape().width + mask = (1 << width) - 1 + # Believe it or not, this is the fastest way to compute a sideways XOR in Python. + return format(op_a & mask, 'b').count('1') % 2 + elif len(value.operands) == 2: + op_a = eval_value(sim, value.operands[0]) + op_b = eval_value(sim, value.operands[1]) + if value.operator == "|": + return op_a | op_b + elif value.operator == "&": + return op_a & op_b + elif value.operator == "^": + return op_a ^ op_b + elif value.operator == "+": + return op_a + op_b + elif value.operator == "-": + return op_a - op_b + elif value.operator == "*": + return op_a * op_b + elif value.operator == "//": + if op_b == 0: + return 0 + return op_a // op_b + elif value.operator == "%": + if op_b == 0: + return 0 + return op_a % op_b + elif value.operator == "<<": + return op_a << op_b + elif value.operator == ">>": + return op_a >> op_b + elif value.operator == "==": + return int(op_a == op_b) + elif value.operator == "!=": + return int(op_a != op_b) + elif value.operator == "<": + return int(op_a < op_b) + elif value.operator == "<=": + return int(op_a <= op_b) + elif value.operator == ">": + return int(op_a > op_b) + elif value.operator == ">=": + return int(op_a >= op_b) + assert False # :nocov: + elif isinstance(value, Slice): + res = eval_value(sim, value.value) + res >>= value.start + width = value.stop - value.start + return res & ((1 << width) - 1) + elif isinstance(value, Part): + res = eval_value(sim, value.value) + offset = eval_value(sim, value.offset) + offset *= value.stride + res >>= offset + return res & ((1 << value.width) - 1) + elif isinstance(value, Concat): + res = 0 + pos = 0 + for part in value.parts: + width = len(part) + part = eval_value(sim, part) + part &= (1 << width) - 1 + res |= part << pos + pos += width + return res + elif isinstance(value, SwitchValue): + test = eval_value(sim, value.test) + for patterns, val in value.cases: + if _eval_matches(test, patterns): + return eval_value(sim, val) + return 0 + elif isinstance(value, Signal): + slot = sim.get_signal(value) + return sim.slots[slot].curr + elif isinstance(value, (ResetSignal, ClockSignal, AnyValue, Initial)): + raise ValueError(f"Value {value!r} cannot be used in simulation") + else: + assert False # :nocov: + + +def _eval_assign_inner(sim, lhs, lhs_start, rhs, rhs_len): + if isinstance(lhs, Operator) and lhs.operator in ("u", "s"): + _eval_assign_inner(sim, lhs.operands[0], lhs_start, rhs, rhs_len) + elif isinstance(lhs, Signal): + lhs_stop = lhs_start + rhs_len + if lhs_stop > len(lhs): + lhs_stop = len(lhs) + if lhs_start >= len(lhs): + return + slot = sim.get_signal(lhs) + value = sim.slots[slot].next + mask = (1 << lhs_stop) - (1 << lhs_start) + value &= ~mask + value |= (rhs << lhs_start) & mask + value &= (1 << len(lhs)) - 1 + if lhs._signed and (value & (1 << (len(lhs) - 1))): + value |= -1 << (len(lhs) - 1) + sim.slots[slot].set(value) + elif isinstance(lhs, Slice): + _eval_assign_inner(sim, lhs.value, lhs_start + lhs.start, rhs, rhs_len) + elif isinstance(lhs, Concat): + part_stop = 0 + for part in lhs.parts: + part_start = part_stop + part_len = len(part) + part_stop = part_start + part_len + if lhs_start >= part_stop: + continue + if lhs_start + rhs_len <= part_start: + continue + if lhs_start < part_start: + part_lhs_start = 0 + part_rhs_start = part_start - lhs_start + else: + part_lhs_start = lhs_start - part_start + part_rhs_start = 0 + if lhs_start + rhs_len >= part_stop: + part_rhs_len = part_stop - lhs_start - part_rhs_start + else: + part_rhs_len = rhs_len - part_rhs_start + part_rhs = rhs >> part_rhs_start + part_rhs &= (1 << part_rhs_len) - 1 + _eval_assign_inner(sim, part, part_lhs_start, part_rhs, part_rhs_len) + elif isinstance(lhs, Part): + offset = eval_value(sim, lhs.offset) + offset *= lhs.stride + _eval_assign_inner(sim, lhs.value, lhs_start + offset, rhs, rhs_len) + elif isinstance(lhs, SwitchValue): + test = eval_value(sim, lhs.test) + for patterns, val in lhs.cases: + if _eval_matches(test, patterns): + _eval_assign_inner(sim, val, lhs_start, rhs, rhs_len) + return + else: + raise ValueError(f"Value {lhs!r} cannot be assigned") + +def eval_assign(sim, lhs, value): + _eval_assign_inner(sim, lhs, 0, value, len(lhs)) \ No newline at end of file diff --git a/tests/test_sim.py b/tests/test_sim.py index bb5d495..0f139b7 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -43,11 +43,30 @@ class SimulatorUnitTestCase(FHDLTestCase): with sim.write_vcd("test.vcd", "test.gtkw", traces=[*isigs, osig]): sim.run() + frag = Fragment() + sim = Simulator(frag) + def process(): + for isig, input in zip(isigs, inputs): + yield isig.eq(input) + yield Delay(0) + if isinstance(stmt, Assign): + yield stmt + else: + yield from stmt + yield Delay(0) + self.assertEqual((yield osig), output.value) + sim.add_testbench(process) + with sim.write_vcd("test.vcd", "test.gtkw", traces=[*isigs, osig]): + sim.run() + + def test_invert(self): stmt = lambda y, a: y.eq(~a) self.assertStatement(stmt, [C(0b0000, 4)], C(0b1111, 4)) self.assertStatement(stmt, [C(0b1010, 4)], C(0b0101, 4)) self.assertStatement(stmt, [C(0, 4)], C(-1, 4)) + self.assertStatement(stmt, [C(0b0000, signed(4))], C(-1, signed(4))) + self.assertStatement(stmt, [C(0b1010, signed(4))], C(0b0101, signed(4))) def test_neg(self): stmt = lambda y, a: y.eq(-a) @@ -126,6 +145,7 @@ class SimulatorUnitTestCase(FHDLTestCase): def test_floordiv(self): stmt = lambda y, a, b: y.eq(a // b) + self.assertStatement(stmt, [C(2, 4), C(0, 4)], C(0, 8)) self.assertStatement(stmt, [C(2, 4), C(1, 4)], C(2, 8)) self.assertStatement(stmt, [C(2, 4), C(2, 4)], C(1, 8)) self.assertStatement(stmt, [C(7, 4), C(2, 4)], C(3, 8)) @@ -285,6 +305,17 @@ class SimulatorUnitTestCase(FHDLTestCase): stmt = lambda y, a: [Cat(l, m, n).eq(a), y.eq(Cat(n, m, l))] self.assertStatement(stmt, [C(0b100101110, 9)], C(0b110101100, 9)) + def test_cat_slice_lhs(self): + l = Signal(3) + m = Signal(3) + n = Signal(3) + o = Signal(3) + p = Signal(3) + stmt = lambda y, a: [Cat(l, m, n, o, p).eq(-1), Cat(l, m, n, o, p)[4:11].eq(a), y.eq(Cat(p, o, n, m, l))] + self.assertStatement(stmt, [C(0b0000000, 7)], C(0b111001000100111, 15)) + self.assertStatement(stmt, [C(0b1001011, 7)], C(0b111111010110111, 15)) + self.assertStatement(stmt, [C(0b1111111, 7)], C(0b111111111111111, 15)) + def test_nested_cat_lhs(self): l = Signal(3) m = Signal(3) @@ -327,6 +358,16 @@ class SimulatorUnitTestCase(FHDLTestCase): self.assertStatement(stmt, [C(1), C(0b010)], C(0b111010001)) self.assertStatement(stmt, [C(2), C(0b100)], C(0b100100001)) + def test_array_lhs_heterogenous(self): + l = Signal(1, init=1) + m = Signal(3, init=4) + n = Signal(5, init=7) + array = Array([l, m, n]) + stmt = lambda y, a, b: [array[a].eq(b), y.eq(Cat(*array))] + self.assertStatement(stmt, [C(0), C(0b000)], C(0b001111000, 9)) + self.assertStatement(stmt, [C(1), C(0b010)], C(0b001110101, 9)) + self.assertStatement(stmt, [C(2), C(0b100)], C(0b001001001, 9)) + def test_array_lhs_oob(self): l = Signal(3) m = Signal(3)