From ccf7aaf00db54c7647b2f0f0cfdf34835c16fa8f Mon Sep 17 00:00:00 2001 From: Wanda Date: Thu, 5 Oct 2023 13:53:56 +0200 Subject: [PATCH] sim._pyrtl: fix masking for bitwise operands and muxes. Fixes #926. --- amaranth/sim/_pyrtl.py | 8 ++++---- tests/test_sim.py | 6 ++++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/amaranth/sim/_pyrtl.py b/amaranth/sim/_pyrtl.py index 0fa23af..cd78c6f 100644 --- a/amaranth/sim/_pyrtl.py +++ b/amaranth/sim/_pyrtl.py @@ -170,11 +170,11 @@ class _RHSValueCompiler(_ValueCompiler): if value.operator == "%": return f"zmod({sign(lhs)}, {sign(rhs)})" if value.operator == "&": - return f"({mask(lhs)} & {mask(rhs)})" + return f"({sign(lhs)} & {sign(rhs)})" if value.operator == "|": - return f"({mask(lhs)} | {mask(rhs)})" + return f"({sign(lhs)} | {sign(rhs)})" if value.operator == "^": - return f"({mask(lhs)} ^ {mask(rhs)})" + return f"({sign(lhs)} ^ {sign(rhs)})" if value.operator == "<<": return f"({sign(lhs)} << {sign(rhs)})" if value.operator == ">>": @@ -194,7 +194,7 @@ class _RHSValueCompiler(_ValueCompiler): elif len(value.operands) == 3: if value.operator == "m": sel, val1, val0 = value.operands - return f"({self(val1)} if {mask(sel)} else {self(val0)})" + return f"({sign(val1)} if {mask(sel)} else {sign(val0)})" raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov: def on_Slice(self, value): diff --git a/tests/test_sim.py b/tests/test_sim.py index 1421609..60fd58e 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -152,6 +152,9 @@ class SimulatorUnitTestCase(FHDLTestCase): def test_and(self): stmt = lambda y, a, b: y.eq(a & b) self.assertStatement(stmt, [C(0b1100, 4), C(0b1010, 4)], C(0b1000, 4)) + self.assertStatement(stmt, [C(0b1010, 4), C(0b10, signed(2))], C(0b1010, 4)) + stmt = lambda y, a: y.eq(a) + self.assertStatement(stmt, [C(0b1010, 4) & C(-2, 2).as_unsigned()], C(0b0010, 4)) def test_or(self): stmt = lambda y, a, b: y.eq(a | b) @@ -211,6 +214,9 @@ class SimulatorUnitTestCase(FHDLTestCase): stmt = lambda y, a, b, c: y.eq(Mux(c, a, b)) self.assertStatement(stmt, [C(2, 4), C(3, 4), C(0)], C(3, 4)) self.assertStatement(stmt, [C(2, 4), C(3, 4), C(1)], C(2, 4)) + stmt = lambda y, a: y.eq(a) + self.assertStatement(stmt, [Mux(0, C(0b1010, 4), C(0b10, 2).as_signed())], C(0b1110, 4)) + self.assertStatement(stmt, [Mux(0, C(0b1010, 4), C(-2, 2).as_unsigned())], C(0b0010, 4)) def test_mux_invert(self): stmt = lambda y, a, b, c: y.eq(Mux(~c, a, b))